Skip to main content

albert_runtime/
oauth.rs

1use std::collections::BTreeMap;
2use std::fs;
3use std::io;
4use std::path::PathBuf;
5
6use serde::{Deserialize, Serialize};
7use serde_json::{Map, Value};
8use sha2::{Digest, Sha256};
9
10use crate::config::OAuthConfig;
11
12#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
13pub struct OAuthTokenSet {
14    pub access_token: String,
15    pub refresh_token: Option<String>,
16    pub expires_at: Option<u64>,
17    pub scopes: Vec<String>,
18}
19
20#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
21pub struct ProviderConfig {
22    pub api_key: Option<String>,
23    pub model: Option<String>,
24    pub base_url: Option<String>,
25}
26
27pub fn credentials_path() -> io::Result<PathBuf> {
28    Ok(credentials_home_dir()?.join("credentials.json"))
29}
30
31pub fn load_provider_config(provider: &str) -> io::Result<Option<ProviderConfig>> {
32    let path = credentials_path()?;
33    let root = read_credentials_root(&path)?;
34    let Some(providers) = root.get("providers") else {
35        return Ok(None);
36    };
37    let Some(config) = providers.get(provider) else {
38        return Ok(None);
39    };
40    serde_json::from_value::<ProviderConfig>(config.clone())
41        .map(Some)
42        .map_err(|error| io::Error::new(io::ErrorKind::InvalidData, error))
43}
44
45pub fn save_provider_config(provider: &str, config: ProviderConfig) -> io::Result<()> {
46    let path = credentials_path()?;
47    let mut root = read_credentials_root(&path)?;
48    let mut providers = root.get("providers")
49        .and_then(|v| v.as_object())
50        .cloned()
51        .unwrap_or_default();
52    
53    providers.insert(
54        provider.to_string(),
55        serde_json::to_value(config).map_err(|error| io::Error::new(io::ErrorKind::InvalidData, error))?,
56    );
57    
58    root.insert("providers".to_string(), Value::Object(providers));
59    write_credentials_root(&path, &root)
60}
61
62pub fn load_oauth_credentials() -> io::Result<Option<OAuthTokenSet>> {
63    let path = credentials_path()?;
64    let root = read_credentials_root(&path)?;
65    let Some(oauth) = root.get("oauth") else {
66        return Ok(None);
67    };
68    serde_json::from_value::<OAuthTokenSet>(oauth.clone())
69        .map(Some)
70        .map_err(|error| io::Error::new(io::ErrorKind::InvalidData, error))
71}
72
73pub fn save_oauth_credentials(token_set: &OAuthTokenSet) -> io::Result<()> {
74    let path = credentials_path()?;
75    let mut root = read_credentials_root(&path)?;
76    root.insert(
77        "oauth".to_string(),
78        serde_json::to_value(token_set).map_err(|error| io::Error::new(io::ErrorKind::InvalidData, error))?,
79    );
80    write_credentials_root(&path, &root)
81}
82
83pub fn clear_oauth_credentials() -> io::Result<()> {
84    let path = credentials_path()?;
85    let mut root = read_credentials_root(&path)?;
86    root.remove("oauth");
87    write_credentials_root(&path, &root)
88}
89
90fn credentials_home_dir() -> io::Result<PathBuf> {
91    if let Some(path) = std::env::var_os("TERNLANG_CONFIG_HOME") {
92        return Ok(PathBuf::from(path));
93    }
94    let home = std::env::var_os("HOME")
95        .ok_or_else(|| io::Error::new(io::ErrorKind::NotFound, "HOME is not set"))?;
96    Ok(PathBuf::from(home).join(".ternlang"))
97}
98
99fn read_credentials_root(path: &PathBuf) -> io::Result<Map<String, Value>> {
100    match fs::read_to_string(path) {
101        Ok(contents) => {
102            if contents.trim().is_empty() {
103                return Ok(Map::new());
104            }
105            serde_json::from_str::<Value>(&contents)
106                .map_err(|error| io::Error::new(io::ErrorKind::InvalidData, error))?
107                .as_object()
108                .cloned()
109                .ok_or_else(|| {
110                    io::Error::new(
111                        io::ErrorKind::InvalidData,
112                        "credentials file must contain a JSON object",
113                    )
114                })
115        }
116        Err(error) if error.kind() == io::ErrorKind::NotFound => Ok(Map::new()),
117        Err(error) => Err(error),
118    }
119}
120
121fn write_credentials_root(path: &PathBuf, root: &Map<String, Value>) -> io::Result<()> {
122    if let Some(parent) = path.parent() {
123        fs::create_dir_all(parent)?;
124    }
125    let rendered = serde_json::to_string_pretty(&Value::Object(root.clone()))
126        .map_err(|error| io::Error::new(io::ErrorKind::InvalidData, error))?;
127    let temp_path = path.with_extension("json.tmp");
128    fs::write(&temp_path, format!("{rendered}\n"))?;
129    fs::rename(temp_path, path)
130}
131
132// Stubs for remaining items to keep the file compilable if needed
133pub fn generate_pkce_pair() -> io::Result<PkceCodePair> { Ok(PkceCodePair { verifier: "".to_string(), challenge: "".to_string(), challenge_method: PkceChallengeMethod::S256 }) }
134pub fn generate_state() -> io::Result<String> { Ok("".to_string()) }
135pub fn loopback_redirect_uri(port: u16) -> String { format!("http://localhost:{port}/callback") }
136pub fn parse_oauth_callback_request_target(_target: &str) -> Result<OAuthCallbackParams, String> { Ok(OAuthCallbackParams { code: None, state: None, error: None, error_description: None }) }
137
138#[derive(Debug, Clone, PartialEq, Eq)]
139pub struct PkceCodePair { pub verifier: String, pub challenge: String, pub challenge_method: PkceChallengeMethod }
140#[derive(Debug, Clone, Copy, PartialEq, Eq)]
141pub enum PkceChallengeMethod { S256 }
142impl PkceChallengeMethod { pub const fn as_str(self) -> &'static str { "S256" } }
143#[derive(Debug, Clone, PartialEq, Eq)]
144pub struct OAuthAuthorizationRequest { pub authorize_url: String, pub client_id: String, pub redirect_uri: String, pub scopes: Vec<String>, pub state: String, pub code_challenge: String, pub code_challenge_method: PkceChallengeMethod, pub extra_params: BTreeMap<String, String> }
145#[derive(Debug, Clone, PartialEq, Eq)]
146pub struct OAuthTokenExchangeRequest { pub grant_type: &'static str, pub code: String, pub redirect_uri: String, pub client_id: String, pub code_verifier: String, pub state: String }
147#[derive(Debug, Clone, PartialEq, Eq)]
148pub struct OAuthRefreshRequest { pub grant_type: &'static str, pub refresh_token: String, pub client_id: String, pub scopes: Vec<String> }
149#[derive(Debug, Clone, PartialEq, Eq)]
150pub struct OAuthCallbackParams { pub code: Option<String>, pub state: Option<String>, pub error: Option<String>, pub error_description: Option<String> }
151impl OAuthAuthorizationRequest { pub fn from_config(_config: &OAuthConfig, _redirect_uri: String, _state: String, _pkce: &PkceCodePair) -> Self { Self { authorize_url: "".to_string(), client_id: "".to_string(), redirect_uri: "".to_string(), scopes: vec![], state: "".to_string(), code_challenge: "".to_string(), code_challenge_method: PkceChallengeMethod::S256, extra_params: BTreeMap::new() } } pub fn build_url(&self) -> String { "".to_string() } }
152impl OAuthTokenExchangeRequest { pub fn from_config(_config: &OAuthConfig, _code: String, _state: String, _verifier: String, _redirect_uri: String) -> Self { Self { grant_type: "authorization_code", code: "".to_string(), redirect_uri: "".to_string(), client_id: "".to_string(), code_verifier: "".to_string(), state: "".to_string() } } }
153
154pub fn code_challenge_s256(verifier: &str) -> String {
155    let mut hasher = Sha256::new();
156    hasher.update(verifier.as_bytes());
157    let hash = hasher.finalize();
158    // Simplified hex for now, should be base64-url-no-padding in production
159    hash.iter().map(|b| format!("{b:02x}")).collect()
160}
161
162pub fn parse_oauth_callback_query(_query: &str) -> Result<OAuthCallbackParams, String> {
163    Ok(OAuthCallbackParams { code: None, state: None, error: None, error_description: None })
164}