use std::path::PathBuf;
use std::sync::Arc;
use crate::llm::oauth_helpers::OAUTH_CALLBACK_PORT;
use chrono::{DateTime, Utc};
use reqwest::Client;
use secrecy::SecretString;
use serde::{Deserialize, Serialize};
use tokio::sync::{Mutex, RwLock};
use crate::llm::error::LlmError;
/// Session state persisted to the session file and, when a store is attached,
/// mirrored into the DB settings.
///
/// The field order below is the key order of the serialized JSON.
#[derive(Clone, Serialize, Deserialize)]
pub struct SessionData {
    /// Session token returned by the NEAR AI OAuth callback.
    pub session_token: String,
    /// UTC timestamp of when the session was saved.
    pub created_at: DateTime<Utc>,
    /// OAuth provider that issued the token (e.g. `"github"`, `"google"`), if recorded.
    ///
    /// `#[serde(default)]` lets session files written before this field existed
    /// deserialize with `None`.
    #[serde(default)]
    pub auth_provider: Option<String>,
}

/// Debug output redacts `session_token` so a `SessionData` can be logged with
/// `{:?}` without leaking the credential (the derived impl printed it verbatim).
impl std::fmt::Debug for SessionData {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("SessionData")
            .field("session_token", &"[REDACTED]")
            .field("created_at", &self.created_at)
            .field("auth_provider", &self.auth_provider)
            .finish()
    }
}
/// Where the NEAR AI auth service lives and where the session is persisted.
#[derive(Debug, Clone)]
pub struct SessionConfig {
    /// Base URL of the auth service; API paths (e.g. `/v1/users/me`) are appended to it.
    pub auth_base_url: String,
    /// File the session JSON is read from and written to.
    pub session_path: PathBuf,
}

impl Default for SessionConfig {
    /// Production auth endpoint, with the session stored in the relative path
    /// `session.json`.
    fn default() -> Self {
        let auth_base_url = String::from("https://private.near.ai");
        let session_path: PathBuf = "session.json".into();
        SessionConfig {
            auth_base_url,
            session_path,
        }
    }
}
/// Holds the NEAR AI session token and manages obtaining, validating and
/// persisting it (to the session file and, optionally, a DB settings store).
///
/// All mutable state sits behind tokio async locks, so one manager can be
/// shared between tasks via `Arc`.
pub struct SessionManager {
    /// Auth service base URL and session file location.
    config: SessionConfig,
    /// HTTP client used for token validation requests.
    client: Client,
    /// The current session token, if one has been loaded or obtained.
    token: RwLock<Option<SecretString>>,
    /// Taken by `handle_auth_failure` so re-authentication attempts run one at a time.
    renewal_lock: Mutex<()>,
    /// Optional database used to mirror and load the session (set via `attach_store`).
    store: RwLock<Option<Arc<dyn crate::db::Database>>>,
    /// User id under which the session is stored in DB settings; starts as "default".
    user_id: RwLock<String>,
}
impl SessionManager {
pub fn new(config: SessionConfig) -> Self {
let manager = Self {
config,
client: Client::builder()
.timeout(std::time::Duration::from_secs(30))
.build()
.unwrap_or_else(|_| Client::new()),
token: RwLock::new(None),
renewal_lock: Mutex::new(()),
store: RwLock::new(None),
user_id: RwLock::new("default".to_string()),
};
if let Ok(data) = std::fs::read_to_string(&manager.config.session_path)
&& let Ok(session) = serde_json::from_str::<SessionData>(&data)
{
if let Ok(mut guard) = manager.token.try_write() {
*guard = Some(SecretString::from(session.session_token));
tracing::info!(
"Loaded session token from {}",
manager.config.session_path.display()
);
}
}
manager
}
pub async fn new_async(config: SessionConfig) -> Self {
let manager = Self {
config,
client: Client::builder()
.timeout(std::time::Duration::from_secs(30))
.build()
.unwrap_or_else(|_| Client::new()),
token: RwLock::new(None),
renewal_lock: Mutex::new(()),
store: RwLock::new(None),
user_id: RwLock::new("default".to_string()),
};
if let Err(e) = manager.load_session().await {
tracing::debug!("No existing session found: {}", e);
}
manager
}
pub async fn attach_store(&self, store: Arc<dyn crate::db::Database>, user_id: &str) {
*self.store.write().await = Some(store);
*self.user_id.write().await = user_id.to_string();
if let Err(e) = self.load_session_from_db().await {
tracing::debug!("No session in DB: {}", e);
}
}
pub async fn get_token(&self) -> Result<SecretString, LlmError> {
let guard = self.token.read().await;
guard.clone().ok_or_else(|| LlmError::AuthFailed {
provider: "nearai".to_string(),
})
}
pub async fn has_token(&self) -> bool {
self.token.read().await.is_some()
}
pub async fn ensure_authenticated(&self) -> Result<(), LlmError> {
if !self.has_token().await {
return self.initiate_login().await;
}
tracing::debug!("Validating session...");
match self.validate_token().await {
Ok(()) => {
tracing::debug!("Session valid");
Ok(())
}
Err(e) => {
tracing::info!("Session expired or invalid: {}", e);
self.initiate_login().await
}
}
}
async fn validate_token(&self) -> Result<(), LlmError> {
use secrecy::ExposeSecret;
let token = self.get_token().await?;
let url = format!("{}/v1/users/me", self.config.auth_base_url);
let response = self
.client
.get(&url)
.header("Authorization", format!("Bearer {}", token.expose_secret()))
.send()
.await
.map_err(|e| LlmError::SessionRenewalFailed {
provider: "nearai".to_string(),
reason: format!("Validation request failed: {}", e),
})?;
if response.status().is_success() {
return Ok(());
}
if response.status().as_u16() == 401 {
return Err(LlmError::SessionExpired {
provider: "nearai".to_string(),
});
}
let status = response.status();
let body = response.text().await.unwrap_or_default();
let preview = crate::agent::truncate_for_preview(&body, 200);
Err(LlmError::SessionRenewalFailed {
provider: "nearai".to_string(),
reason: format!("Validation failed: HTTP {status}: {preview}"),
})
}
pub async fn handle_auth_failure(&self) -> Result<(), LlmError> {
let _guard = self.renewal_lock.lock().await;
tracing::info!("Session expired or invalid, re-authenticating...");
self.initiate_login().await
}
async fn initiate_login(&self) -> Result<(), LlmError> {
use crate::llm::oauth_helpers;
let cb_url = oauth_helpers::callback_url();
let host = oauth_helpers::callback_host();
println!();
println!("╔════════════════════════════════════════════════════════════════╗");
println!("║ NEAR AI Authentication ║");
println!("╠════════════════════════════════════════════════════════════════╣");
println!("║ Choose an authentication method: ║");
println!("║ ║");
println!("║ [1] GitHub (requires localhost browser access) ║");
println!("║ [2] Google (requires localhost browser access) ║");
println!("║ [3] NEAR Wallet (coming soon) ║");
println!("║ [4] NEAR AI Cloud API key ║");
println!("║ ║");
println!("╚════════════════════════════════════════════════════════════════╝");
println!();
print!("Enter choice [1-4]: ");
use std::io::Write;
std::io::stdout().flush().ok();
let mut choice = String::new();
std::io::stdin()
.read_line(&mut choice)
.map_err(|e| LlmError::SessionRenewalFailed {
provider: "nearai".to_string(),
reason: format!("Failed to read input: {}", e),
})?;
match choice.trim() {
"4" => return self.api_key_login().await,
"3" => {
println!();
println!("NEAR Wallet authentication is not yet implemented.");
println!("Please use GitHub or Google for now.");
return Err(LlmError::SessionRenewalFailed {
provider: "nearai".to_string(),
reason: "NEAR Wallet auth not yet implemented".to_string(),
});
}
"1" | "" | "2" => {} other => {
return Err(LlmError::SessionRenewalFailed {
provider: "nearai".to_string(),
reason: format!("Invalid choice: {}", other),
});
}
}
if !oauth_helpers::is_loopback_host(&host) {
println!();
println!("Warning: OAuth callback is using plain HTTP to a remote host ({host}).");
println!(" The session token will be transmitted unencrypted.");
println!(" Consider SSH port forwarding instead:");
println!(
" ssh -L {OAUTH_CALLBACK_PORT}:127.0.0.1:{OAUTH_CALLBACK_PORT} user@{host}"
);
}
let listener = oauth_helpers::bind_callback_listener().await.map_err(|e| {
LlmError::SessionRenewalFailed {
provider: "nearai".to_string(),
reason: e.to_string(),
}
})?;
let (auth_provider, auth_url) = match choice.trim() {
"2" => {
let url = format!(
"{}/v1/auth/google?frontend_callback={}",
self.config.auth_base_url,
urlencoding::encode(&cb_url)
);
("google", url)
}
_ => {
let url = format!(
"{}/v1/auth/github?frontend_callback={}",
self.config.auth_base_url,
urlencoding::encode(&cb_url)
);
("github", url)
}
};
println!();
println!("Opening {} authentication...", auth_provider);
println!();
println!(" {}", auth_url);
println!();
if let Err(e) = open::that(&auth_url) {
tracing::debug!("Could not open browser automatically: {}", e);
println!("(Could not open browser automatically, please copy the URL above)");
} else {
println!("(Opening browser...)");
}
println!();
println!("Waiting for authentication...");
let session_token =
oauth_helpers::wait_for_callback(listener, "/auth/callback", "token", "NEAR AI", None)
.await
.map_err(|e| LlmError::SessionRenewalFailed {
provider: "nearai".to_string(),
reason: e.to_string(),
})?;
let auth_provider = Some(auth_provider.to_string());
self.save_session(&session_token, auth_provider.as_deref())
.await?;
{
let mut guard = self.token.write().await;
*guard = Some(SecretString::from(session_token));
}
println!();
println!("✓ Authentication successful!");
println!();
Ok(())
}
async fn api_key_login(&self) -> Result<(), LlmError> {
println!();
println!("NEAR AI Cloud API key");
println!("─────────────────────");
println!();
println!(" 1. Open https://cloud.near.ai in your browser");
println!(" 2. Sign in and navigate to API Keys");
println!(" 3. Create or copy an existing API key");
println!();
let key_secret =
crate::setup::secret_input("API key").map_err(|e| LlmError::SessionRenewalFailed {
provider: "nearai".to_string(),
reason: format!("Failed to read input: {}", e),
})?;
use secrecy::ExposeSecret;
let key = key_secret.expose_secret().to_string();
if key.is_empty() {
return Err(LlmError::SessionRenewalFailed {
provider: "nearai".to_string(),
reason: "API key cannot be empty".to_string(),
});
}
crate::config::helpers::set_runtime_env("NEARAI_API_KEY", &key);
if let Err(e) = crate::bootstrap::upsert_bootstrap_var("NEARAI_API_KEY", &key) {
tracing::warn!("Failed to save API key to bootstrap .env: {}", e);
}
println!();
crate::setup::print_success("NEAR AI Cloud API key saved.");
println!();
Ok(())
}
async fn save_session(&self, token: &str, auth_provider: Option<&str>) -> Result<(), LlmError> {
let session = SessionData {
session_token: token.to_string(),
created_at: Utc::now(),
auth_provider: auth_provider.map(String::from),
};
if let Some(parent) = self.config.session_path.parent() {
tokio::fs::create_dir_all(parent).await.map_err(|e| {
LlmError::Io(std::io::Error::new(
e.kind(),
format!("Failed to create session directory: {}", e),
))
})?;
}
let json =
serde_json::to_string_pretty(&session).map_err(|e| LlmError::SessionRenewalFailed {
provider: "nearai".to_string(),
reason: format!("Failed to serialize session: {}", e),
})?;
tokio::fs::write(&self.config.session_path, json)
.await
.map_err(|e| {
LlmError::Io(std::io::Error::new(
e.kind(),
format!(
"Failed to write session file {}: {}",
self.config.session_path.display(),
e
),
))
})?;
#[cfg(unix)]
{
use std::os::unix::fs::PermissionsExt;
let perms = std::fs::Permissions::from_mode(0o600);
tokio::fs::set_permissions(&self.config.session_path, perms)
.await
.map_err(|e| {
LlmError::Io(std::io::Error::new(
e.kind(),
format!(
"Failed to set permissions on {}: {}",
self.config.session_path.display(),
e
),
))
})?;
}
tracing::debug!("Session saved to {}", self.config.session_path.display());
if let Some(ref store) = *self.store.read().await {
let user_id = self.user_id.read().await.clone();
let session_json = serde_json::to_value(&session)
.unwrap_or(serde_json::Value::String(token.to_string()));
if let Err(e) = store
.set_setting(&user_id, "nearai.session_token", &session_json)
.await
{
tracing::warn!("Failed to save session to DB: {}", e);
} else {
tracing::debug!("Session also saved to DB settings");
}
}
Ok(())
}
async fn load_session_from_db(&self) -> Result<(), LlmError> {
let store_guard = self.store.read().await;
let store = store_guard
.as_ref()
.ok_or_else(|| LlmError::SessionRenewalFailed {
provider: "nearai".to_string(),
reason: "No DB store attached".to_string(),
})?;
let user_id = self.user_id.read().await.clone();
let value = if let Some(value) = store
.get_setting(&user_id, "nearai.session_token")
.await
.map_err(|e| LlmError::SessionRenewalFailed {
provider: "nearai".to_string(),
reason: format!("DB query failed: {}", e),
})? {
value
} else {
let legacy = store
.get_setting(&user_id, "nearai.session")
.await
.map_err(|e| LlmError::SessionRenewalFailed {
provider: "nearai".to_string(),
reason: format!("DB query failed: {}", e),
})?;
match legacy {
Some(value) => {
tracing::warn!(
"nearai.session_token missing; falling back to legacy nearai.session for backwards compatibility"
);
value
}
None => {
return Err(LlmError::SessionRenewalFailed {
provider: "nearai".to_string(),
reason: "No session in DB".to_string(),
});
}
}
};
let session: SessionData =
serde_json::from_value(value).map_err(|e| LlmError::SessionRenewalFailed {
provider: "nearai".to_string(),
reason: format!("Failed to parse DB session: {}", e),
})?;
let mut guard = self.token.write().await;
*guard = Some(SecretString::from(session.session_token));
tracing::info!("Loaded session from DB settings");
Ok(())
}
async fn load_session(&self) -> Result<(), LlmError> {
let data = tokio::fs::read_to_string(&self.config.session_path)
.await
.map_err(|e| {
LlmError::Io(std::io::Error::new(
e.kind(),
format!(
"Failed to read session file {}: {}",
self.config.session_path.display(),
e
),
))
})?;
let session: SessionData =
serde_json::from_str(&data).map_err(|e| LlmError::SessionRenewalFailed {
provider: "nearai".to_string(),
reason: format!("Failed to parse session file: {}", e),
})?;
{
let mut guard = self.token.write().await;
*guard = Some(SecretString::from(session.session_token));
}
tracing::info!(
"Loaded session from {} (created: {})",
self.config.session_path.display(),
session.created_at
);
Ok(())
}
pub async fn set_token(&self, token: SecretString) {
let mut guard = self.token.write().await;
*guard = Some(token);
}
}
/// Builds a shared [`SessionManager`], restoring any session from the configured file.
///
/// A non-empty `NEARAI_SESSION_TOKEN` environment variable takes precedence
/// over a token loaded from disk. The env token is kept in memory only.
pub async fn create_session_manager(config: SessionConfig) -> Arc<SessionManager> {
    let manager = SessionManager::new_async(config).await;
    let env_token = std::env::var("NEARAI_SESSION_TOKEN")
        .ok()
        .filter(|token| !token.is_empty());
    if let Some(token) = env_token {
        tracing::info!("Using session token from NEARAI_SESSION_TOKEN env var");
        manager.set_token(SecretString::from(token)).await;
    }
    Arc::new(manager)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::testing::credentials::{
        TEST_SESSION_NEARAI_ABC, TEST_SESSION_NEARAI_XYZ, TEST_SESSION_TOKEN,
    };
    use secrecy::ExposeSecret;
    use std::path::Path;
    use tempfile::tempdir;

    /// Config whose session file is `file_name` inside `dir`, with a placeholder auth URL.
    fn config_in(dir: &Path, file_name: &str) -> SessionConfig {
        SessionConfig {
            auth_base_url: "https://example.com".to_string(),
            session_path: dir.join(file_name),
        }
    }

    /// Parses the session JSON found at `path`.
    fn read_session_file(path: &Path) -> SessionData {
        let contents = std::fs::read_to_string(path).unwrap();
        serde_json::from_str(&contents).unwrap()
    }

    /// Serializes `data` to JSON and parses it back.
    fn json_roundtrip(data: &SessionData) -> SessionData {
        let json = serde_json::to_string(data).unwrap();
        serde_json::from_str(&json).unwrap()
    }

    #[tokio::test]
    async fn test_session_save_load() {
        let dir = tempdir().unwrap();
        let config = config_in(dir.path(), "session.json");

        let manager = SessionManager::new_async(config.clone()).await;
        assert!(!manager.has_token().await);
        manager
            .save_session(TEST_SESSION_TOKEN, Some("near"))
            .await
            .unwrap();
        manager
            .set_token(SecretString::from(TEST_SESSION_TOKEN))
            .await;
        assert!(manager.has_token().await);
        let token = manager.get_token().await.unwrap();
        assert_eq!(token.expose_secret(), TEST_SESSION_TOKEN);

        // A second manager over the same file must restore the saved token.
        let reloaded = SessionManager::new_async(config.clone()).await;
        assert!(reloaded.has_token().await);
        let token = reloaded.get_token().await.unwrap();
        assert_eq!(token.expose_secret(), TEST_SESSION_TOKEN);

        let on_disk = read_session_file(&config.session_path);
        assert_eq!(on_disk.session_token, TEST_SESSION_TOKEN);
        assert_eq!(on_disk.auth_provider.as_deref(), Some("near"));
    }

    #[tokio::test]
    async fn test_get_token_without_auth_fails() {
        let dir = tempdir().unwrap();
        let manager = SessionManager::new_async(config_in(dir.path(), "nonexistent.json")).await;
        let outcome = manager.get_token().await;
        assert!(outcome.is_err());
        assert!(matches!(outcome, Err(LlmError::AuthFailed { .. })));
    }

    #[test]
    fn test_session_data_serde_roundtrip_with_auth_provider() {
        let original = SessionData {
            session_token: TEST_SESSION_NEARAI_ABC.to_string(),
            created_at: Utc::now(),
            auth_provider: Some("github".to_string()),
        };
        let restored = json_roundtrip(&original);
        assert_eq!(restored.session_token, original.session_token);
        assert_eq!(restored.auth_provider.as_deref(), Some("github"));
        assert_eq!(restored.created_at, original.created_at);
    }

    #[test]
    fn test_session_data_serde_roundtrip_without_auth_provider() {
        let original = SessionData {
            session_token: TEST_SESSION_NEARAI_XYZ.to_string(),
            created_at: Utc::now(),
            auth_provider: None,
        };
        let restored = json_roundtrip(&original);
        assert_eq!(restored.session_token, original.session_token);
        assert!(restored.auth_provider.is_none());
    }

    #[test]
    fn test_session_data_missing_auth_provider_defaults_to_none() {
        let legacy_json = r#"{"session_token":"tok_legacy","created_at":"2025-01-01T00:00:00Z"}"#;
        let parsed: SessionData = serde_json::from_str(legacy_json).unwrap();
        assert_eq!(parsed.session_token, "tok_legacy");
        assert!(parsed.auth_provider.is_none());
    }

    #[test]
    fn test_session_config_default() {
        let defaults = SessionConfig::default();
        assert_eq!(defaults.auth_base_url, "https://private.near.ai");
        assert!(defaults.session_path.ends_with("session.json"));
    }

    #[tokio::test]
    async fn test_new_with_nonexistent_session_file() {
        let dir = tempdir().unwrap();
        let manager = SessionManager::new(config_in(dir.path(), "does_not_exist.json"));
        assert!(!manager.has_token().await);
    }

    #[tokio::test]
    async fn test_set_token_get_token_roundtrip() {
        let dir = tempdir().unwrap();
        let manager = SessionManager::new(config_in(dir.path(), "session.json"));
        manager
            .set_token(SecretString::from("my_secret_token"))
            .await;
        let token = manager.get_token().await.unwrap();
        assert_eq!(token.expose_secret(), "my_secret_token");
    }

    #[tokio::test]
    async fn test_has_token_false_then_true() {
        let dir = tempdir().unwrap();
        let manager = SessionManager::new(config_in(dir.path(), "session.json"));
        assert!(!manager.has_token().await);
        manager.set_token(SecretString::from("tok_something")).await;
        assert!(manager.has_token().await);
    }

    #[tokio::test]
    async fn test_save_session_then_load_in_new_manager() {
        let dir = tempdir().unwrap();
        let config = config_in(dir.path(), "session.json");

        let writer = SessionManager::new_async(config.clone()).await;
        writer
            .save_session("persist_me", Some("google"))
            .await
            .unwrap();

        let reader = SessionManager::new_async(config.clone()).await;
        assert!(reader.has_token().await);
        let token = reader.get_token().await.unwrap();
        assert_eq!(token.expose_secret(), "persist_me");

        let on_disk = read_session_file(&config.session_path);
        assert_eq!(on_disk.auth_provider.as_deref(), Some("google"));
    }

    #[tokio::test]
    async fn test_save_session_with_no_auth_provider() {
        let dir = tempdir().unwrap();
        let config = config_in(dir.path(), "session.json");
        let session_path = config.session_path.clone();

        let manager = SessionManager::new_async(config).await;
        manager.save_session("anon_tok", None).await.unwrap();

        let on_disk = read_session_file(&session_path);
        assert_eq!(on_disk.session_token, "anon_tok");
        assert!(on_disk.auth_provider.is_none());
    }

    #[cfg(unix)]
    #[tokio::test]
    async fn test_session_file_permissions() {
        use std::os::unix::fs::PermissionsExt;
        let dir = tempdir().unwrap();
        let config = config_in(dir.path(), "session.json");
        let session_path = config.session_path.clone();

        let manager = SessionManager::new_async(config).await;
        manager
            .save_session("secret_tok", Some("github"))
            .await
            .unwrap();

        let mode = std::fs::metadata(&session_path)
            .unwrap()
            .permissions()
            .mode()
            & 0o777;
        assert_eq!(mode, 0o600, "Session file should have 0600 permissions");
    }
}