use std::collections::BTreeMap;
use std::fs;
use std::io;
use std::path::PathBuf;
use serde::{Deserialize, Serialize};
use serde_json::{Map, Value};
use sha2::{Digest, Sha256};
use crate::config::OAuthConfig;
/// OAuth token set persisted under the `oauth` key of the credentials file.
///
/// `Debug` is implemented by hand rather than derived so that the access and refresh tokens
/// are redacted; a derived impl would print the secrets verbatim into any `{:?}` log line,
/// `assert_eq!` failure, or panic message.
#[derive(Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct OAuthTokenSet {
    /// The access token.
    pub access_token: String,
    /// The refresh token, if one was issued.
    pub refresh_token: Option<String>,
    /// Expiry time, if known. NOTE(review): unit/epoch is not established here — presumably
    /// Unix seconds; confirm against the code that fills it.
    pub expires_at: Option<u64>,
    /// Scopes associated with this token set.
    pub scopes: Vec<String>,
}

impl std::fmt::Debug for OAuthTokenSet {
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        formatter
            .debug_struct("OAuthTokenSet")
            .field("access_token", &"<redacted>")
            // Keep the Some/None distinction visible without revealing the token.
            .field(
                "refresh_token",
                &self.refresh_token.as_ref().map(|_| "<redacted>"),
            )
            .field("expires_at", &self.expires_at)
            .field("scopes", &self.scopes)
            .finish()
    }
}
/// Per-provider settings stored under `providers.<provider>` in the credentials file.
///
/// `Debug` is implemented by hand rather than derived so that `api_key` is redacted and
/// cannot leak through `{:?}` formatting in logs, test failures, or panic messages.
#[derive(Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct ProviderConfig {
    /// API key for the provider, if configured.
    pub api_key: Option<String>,
    /// Model name override, if configured.
    pub model: Option<String>,
    /// Base URL override, if configured.
    pub base_url: Option<String>,
}

impl std::fmt::Debug for ProviderConfig {
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        formatter
            .debug_struct("ProviderConfig")
            // Show whether a key is set, never the key itself.
            .field("api_key", &self.api_key.as_ref().map(|_| "<redacted>"))
            .field("model", &self.model)
            .field("base_url", &self.base_url)
            .finish()
    }
}
/// Returns the location of `credentials.json` inside the credentials home directory.
///
/// # Errors
///
/// Propagates the error from resolving the home directory.
pub fn credentials_path() -> io::Result<PathBuf> {
    credentials_home_dir().map(|directory| directory.join("credentials.json"))
}
/// Loads the stored configuration for `provider` from `providers.<provider>`.
///
/// Returns `Ok(None)` when there is no `providers` entry, or when it has no value for
/// `provider`.
///
/// # Errors
///
/// Propagates I/O and parse errors from reading the credentials file. Returns `InvalidData`
/// when the stored value cannot be deserialized into [`ProviderConfig`].
pub fn load_provider_config(provider: &str) -> io::Result<Option<ProviderConfig>> {
    let root = read_credentials_root(&credentials_path()?)?;
    root.get("providers")
        .and_then(|providers| providers.get(provider))
        .cloned()
        .map(serde_json::from_value::<ProviderConfig>)
        .transpose()
        .map_err(|error| io::Error::new(io::ErrorKind::InvalidData, error))
}
pub fn save_provider_config(provider: &str, config: ProviderConfig) -> io::Result<()> {
let path = credentials_path()?;
let mut root = read_credentials_root(&path)?;
let mut providers = root.get("providers")
.and_then(|v| v.as_object())
.cloned()
.unwrap_or_default();
providers.insert(
provider.to_string(),
serde_json::to_value(config).map_err(|error| io::Error::new(io::ErrorKind::InvalidData, error))?,
);
root.insert("providers".to_string(), Value::Object(providers));
write_credentials_root(&path, &root)
}
/// Loads the OAuth token set stored under the `oauth` key.
///
/// Returns `Ok(None)` when the credentials file has no `oauth` entry.
///
/// # Errors
///
/// Propagates I/O and parse errors from reading the credentials file. Returns `InvalidData`
/// when the `oauth` value cannot be deserialized into [`OAuthTokenSet`].
pub fn load_oauth_credentials() -> io::Result<Option<OAuthTokenSet>> {
    let mut root = read_credentials_root(&credentials_path()?)?;
    // `root` is a throwaway local, so the entry can be moved out instead of cloned.
    match root.remove("oauth") {
        None => Ok(None),
        Some(stored) => serde_json::from_value::<OAuthTokenSet>(stored)
            .map(Some)
            .map_err(|error| io::Error::new(io::ErrorKind::InvalidData, error)),
    }
}
/// Stores `token_set` under the `oauth` key, preserving every other key in the file.
///
/// # Errors
///
/// Propagates I/O and parse errors from reading or writing the credentials file. Returns
/// `InvalidData` if the token set cannot be serialized.
pub fn save_oauth_credentials(token_set: &OAuthTokenSet) -> io::Result<()> {
    let path = credentials_path()?;
    let mut root = read_credentials_root(&path)?;
    let encoded = serde_json::to_value(token_set)
        .map_err(|error| io::Error::new(io::ErrorKind::InvalidData, error))?;
    root.insert(String::from("oauth"), encoded);
    write_credentials_root(&path, &root)
}
/// Removes the `oauth` entry from the credentials file and writes the rest back.
///
/// The file is rewritten even when no `oauth` entry was present.
///
/// # Errors
///
/// Propagates I/O and parse errors from reading or writing the credentials file.
pub fn clear_oauth_credentials() -> io::Result<()> {
    let path = credentials_path()?;
    let remaining = {
        let mut root = read_credentials_root(&path)?;
        let _ = root.remove("oauth");
        root
    };
    write_credentials_root(&path, &remaining)
}
/// Resolves the directory that holds `credentials.json`.
///
/// `TERNLANG_CONFIG_HOME` wins when it is set to a non-empty value. Otherwise the directory
/// is `$HOME/.ternlang`. Empty values are treated as unset: an empty path would otherwise
/// make `credentials.json` resolve silently relative to the current working directory.
///
/// # Errors
///
/// Returns `NotFound` when neither variable provides a non-empty value.
fn credentials_home_dir() -> io::Result<PathBuf> {
    let non_empty = |name: &str| std::env::var_os(name).filter(|value| !value.is_empty());
    if let Some(path) = non_empty("TERNLANG_CONFIG_HOME") {
        return Ok(PathBuf::from(path));
    }
    let home = non_empty("HOME")
        .ok_or_else(|| io::Error::new(io::ErrorKind::NotFound, "HOME is not set"))?;
    Ok(PathBuf::from(home).join(".ternlang"))
}
fn read_credentials_root(path: &PathBuf) -> io::Result<Map<String, Value>> {
match fs::read_to_string(path) {
Ok(contents) => {
if contents.trim().is_empty() {
return Ok(Map::new());
}
serde_json::from_str::<Value>(&contents)
.map_err(|error| io::Error::new(io::ErrorKind::InvalidData, error))?
.as_object()
.cloned()
.ok_or_else(|| {
io::Error::new(
io::ErrorKind::InvalidData,
"credentials file must contain a JSON object",
)
})
}
Err(error) if error.kind() == io::ErrorKind::NotFound => Ok(Map::new()),
Err(error) => Err(error),
}
}
/// Replaces the credentials file at `path` with `root`, pretty-printed plus a trailing newline.
///
/// Missing parent directories are created first. The JSON is written to a sibling
/// `*.json.tmp` file, synced to disk, and then renamed over `path`, so readers see either
/// the old file or the complete new one. The file holds API keys and OAuth tokens, so on
/// Unix the temporary file is created with mode `0o600` (owner read/write only). The rename
/// carries that mode over to `path`.
///
/// # Errors
///
/// Returns any I/O error from creating directories, writing, syncing, or renaming. Returns
/// `InvalidData` if `root` cannot be serialized. On failure the temporary file is removed on
/// a best-effort basis.
fn write_credentials_root(path: &PathBuf, root: &Map<String, Value>) -> io::Result<()> {
    use std::io::Write;

    if let Some(parent) = path.parent() {
        fs::create_dir_all(parent)?;
    }
    // `Map` serializes as a JSON object directly; no need to clone it into a `Value`.
    let rendered = serde_json::to_string_pretty(root)
        .map_err(|error| io::Error::new(io::ErrorKind::InvalidData, error))?;
    let temp_path = path.with_extension("json.tmp");

    // The handle lives only inside this closure, so it is closed before the rename
    // (as it was with the previous `fs::write`).
    let write_temp = || -> io::Result<()> {
        let mut options = fs::OpenOptions::new();
        options.write(true).create(true).truncate(true);
        #[cfg(unix)]
        {
            use std::os::unix::fs::OpenOptionsExt;
            options.mode(0o600);
        }
        let mut file = options.open(&temp_path)?;
        #[cfg(unix)]
        {
            // `mode` only applies to newly created files; also tighten a stale temp file
            // left behind by an earlier interrupted write.
            use std::os::unix::fs::PermissionsExt;
            file.set_permissions(fs::Permissions::from_mode(0o600))?;
        }
        file.write_all(rendered.as_bytes())?;
        file.write_all(b"\n")?;
        // Flush to disk before the rename so a crash cannot leave an empty/truncated file.
        file.sync_all()
    };

    let result = write_temp().and_then(|()| fs::rename(&temp_path, path));
    if result.is_err() {
        // Best effort: don't leave a stray copy of the secrets next to the real file.
        let _ = fs::remove_file(&temp_path);
    }
    result
}
pub fn generate_pkce_pair() -> io::Result<PkceCodePair> { Ok(PkceCodePair { verifier: "".to_string(), challenge: "".to_string(), challenge_method: PkceChallengeMethod::S256 }) }
pub fn generate_state() -> io::Result<String> { Ok("".to_string()) }
/// Builds the loopback redirect URI `http://localhost:<port>/callback`.
pub fn loopback_redirect_uri(port: u16) -> String {
    let mut uri = String::from("http://localhost:");
    uri.push_str(&port.to_string());
    uri.push_str("/callback");
    uri
}
pub fn parse_oauth_callback_request_target(_target: &str) -> Result<OAuthCallbackParams, String> { Ok(OAuthCallbackParams { code: None, state: None, error: None, error_description: None }) }
/// A PKCE code verifier together with the challenge derived from it.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PkceCodePair {
    /// The secret code verifier.
    pub verifier: String,
    /// The challenge derived from `verifier`.
    pub challenge: String,
    /// The method used to derive `challenge`.
    pub challenge_method: PkceChallengeMethod,
}
/// PKCE code-challenge methods supported by this module.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PkceChallengeMethod {
    /// The SHA-256 transformation (RFC 7636 §4.2).
    S256,
}

impl PkceChallengeMethod {
    /// Returns the method's wire name.
    pub const fn as_str(self) -> &'static str {
        // Exhaustive on purpose: adding a variant must force a wire name to be chosen.
        match self {
            Self::S256 => "S256",
        }
    }
}
/// Parameters of an OAuth authorization request.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct OAuthAuthorizationRequest {
    /// Authorization endpoint URL.
    pub authorize_url: String,
    /// OAuth client identifier.
    pub client_id: String,
    /// Redirect URI to return to after authorization.
    pub redirect_uri: String,
    /// Requested scopes.
    pub scopes: Vec<String>,
    /// Opaque `state` value to be echoed back in the callback.
    pub state: String,
    /// PKCE code challenge.
    pub code_challenge: String,
    /// Method used to derive `code_challenge`.
    pub code_challenge_method: PkceChallengeMethod,
    /// Additional query parameters, kept in key order.
    pub extra_params: BTreeMap<String, String>,
}
/// Parameters for exchanging an authorization code for tokens.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct OAuthTokenExchangeRequest {
    /// OAuth grant type.
    pub grant_type: &'static str,
    /// Authorization code received in the callback.
    pub code: String,
    /// Redirect URI used in the authorization request.
    pub redirect_uri: String,
    /// OAuth client identifier.
    pub client_id: String,
    /// PKCE code verifier.
    pub code_verifier: String,
    /// The `state` value associated with this exchange.
    pub state: String,
}
/// Parameters for refreshing an access token.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct OAuthRefreshRequest {
    /// OAuth grant type.
    pub grant_type: &'static str,
    /// Refresh token to redeem.
    pub refresh_token: String,
    /// OAuth client identifier.
    pub client_id: String,
    /// Scopes to request.
    pub scopes: Vec<String>,
}
/// Parameters extracted from an OAuth redirect callback; each is `None` when absent.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct OAuthCallbackParams {
    /// The `code` parameter.
    pub code: Option<String>,
    /// The `state` parameter.
    pub state: Option<String>,
    /// The `error` parameter.
    pub error: Option<String>,
    /// The `error_description` parameter.
    pub error_description: Option<String>,
}
impl OAuthAuthorizationRequest { pub fn from_config(_config: &OAuthConfig, _redirect_uri: String, _state: String, _pkce: &PkceCodePair) -> Self { Self { authorize_url: "".to_string(), client_id: "".to_string(), redirect_uri: "".to_string(), scopes: vec![], state: "".to_string(), code_challenge: "".to_string(), code_challenge_method: PkceChallengeMethod::S256, extra_params: BTreeMap::new() } } pub fn build_url(&self) -> String { "".to_string() } }
impl OAuthTokenExchangeRequest {
    /// Builds the `authorization_code` grant request for exchanging `code` for tokens.
    ///
    /// `code`, `state`, `verifier` (stored as `code_verifier`), and `redirect_uri` are taken
    /// from the arguments. NOTE(review): `client_id` presumably belongs in `config`, but this
    /// function does not read `OAuthConfig` yet, so it is left empty — TODO wire it up.
    pub fn from_config(
        _config: &OAuthConfig,
        code: String,
        state: String,
        verifier: String,
        redirect_uri: String,
    ) -> Self {
        Self {
            grant_type: "authorization_code",
            code,
            redirect_uri,
            client_id: String::new(),
            code_verifier: verifier,
            state,
        }
    }
}
/// Derives the PKCE `S256` code challenge for `verifier` (RFC 7636 §4.2).
///
/// The challenge is `BASE64URL-ENCODE(SHA256(ASCII(verifier)))`: the SHA-256 digest of the
/// verifier's bytes, encoded with the URL-safe base64 alphabet (RFC 4648 §5) and with no `=`
/// padding. For a 32-byte digest that is always 43 characters. Authorization servers
/// recompute exactly this value, so any other encoding (such as hex) makes the token
/// exchange fail.
pub fn code_challenge_s256(verifier: &str) -> String {
    const ALPHABET: &[u8; 64] =
        b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_";

    let mut hasher = Sha256::new();
    hasher.update(verifier.as_bytes());
    let digest = hasher.finalize();

    let mut encoded = String::with_capacity((digest.len() * 4 + 2) / 3);
    for chunk in digest.chunks(3) {
        // Pack up to three bytes big-endian into the low 24 bits.
        let packed = chunk
            .iter()
            .enumerate()
            .fold(0_u32, |acc, (index, &byte)| acc | (u32::from(byte) << (16 - 8 * index)));
        // A chunk of n bytes carries n + 1 meaningful sextets; the rest would be padding.
        for sextet_index in 0..=chunk.len() {
            let sextet = (packed >> (18 - 6 * sextet_index)) & 0x3f;
            // `sextet` is masked to 0..=63, so it always indexes into ALPHABET.
            encoded.push(char::from(ALPHABET[sextet as usize]));
        }
    }
    encoded
}
/// Parses the query string of an OAuth redirect into [`OAuthCallbackParams`].
///
/// Parsing rules:
/// - `query` is the part after `?`; a leading `?` is tolerated.
/// - Pairs are separated by `&`, and empty pairs are skipped.
/// - A pair without `=` has an empty value.
/// - Keys and values are decoded as `application/x-www-form-urlencoded`: `+` becomes a
///   space, `%XX` becomes that byte, and the result must be UTF-8.
/// - Only `code`, `state`, `error`, and `error_description` are kept; other parameters are
///   ignored without decoding their values.
///
/// # Errors
///
/// Returns a message for:
/// - a malformed `%` escape in any key, or in the value of a recognised parameter;
/// - decoded text that is not UTF-8;
/// - a recognised parameter that appears more than once (RFC 6749 §3.1 forbids repeats,
///   and silently picking one would weaken the `state` check).
pub fn parse_oauth_callback_query(query: &str) -> Result<OAuthCallbackParams, String> {
    /// Decodes one form-urlencoded key or value.
    fn decode_component(raw: &str) -> Result<String, String> {
        let bytes = raw.as_bytes();
        let mut decoded = Vec::with_capacity(bytes.len());
        let mut index = 0;
        while index < bytes.len() {
            match bytes[index] {
                b'+' => {
                    decoded.push(b' ');
                    index += 1;
                }
                b'%' => {
                    // `char::to_digit(16)` accepts only hex digits (no sign, unlike
                    // `from_str_radix`).
                    let hex_digit = |offset: usize| {
                        bytes
                            .get(index + offset)
                            .and_then(|&byte| char::from(byte).to_digit(16))
                    };
                    let (Some(high), Some(low)) = (hex_digit(1), hex_digit(2)) else {
                        return Err("invalid percent-encoding in OAuth callback query".to_string());
                    };
                    // Two hex digits are at most 0xFF, so the narrowing cannot truncate.
                    decoded.push((high * 16 + low) as u8);
                    index += 3;
                }
                byte => {
                    decoded.push(byte);
                    index += 1;
                }
            }
        }
        String::from_utf8(decoded)
            .map_err(|_| "OAuth callback query is not valid UTF-8".to_string())
    }

    let mut params = OAuthCallbackParams {
        code: None,
        state: None,
        error: None,
        error_description: None,
    };
    let query = query.strip_prefix('?').unwrap_or(query);
    for pair in query.split('&').filter(|pair| !pair.is_empty()) {
        let (raw_key, raw_value) = pair.split_once('=').unwrap_or((pair, ""));
        let key = decode_component(raw_key)?;
        let slot = match key.as_str() {
            "code" => &mut params.code,
            "state" => &mut params.state,
            "error" => &mut params.error,
            "error_description" => &mut params.error_description,
            _ => continue,
        };
        if slot.is_some() {
            return Err(format!("duplicate OAuth callback parameter: {key}"));
        }
        *slot = Some(decode_component(raw_value)?);
    }
    Ok(params)
}