1use std::collections::BTreeMap;
2use std::fs;
3use std::io;
4use std::path::PathBuf;
5
6use serde::{Deserialize, Serialize};
7use serde_json::{Map, Value};
8use sha2::{Digest, Sha256};
9
10use crate::config::OAuthConfig;
11
12#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
13pub struct OAuthTokenSet {
14 pub access_token: String,
15 pub refresh_token: Option<String>,
16 pub expires_at: Option<u64>,
17 pub scopes: Vec<String>,
18}
19
20#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
21pub struct ProviderConfig {
22 pub api_key: Option<String>,
23 pub model: Option<String>,
24 pub base_url: Option<String>,
25}
26
27pub fn credentials_path() -> io::Result<PathBuf> {
28 Ok(credentials_home_dir()?.join("credentials.json"))
29}
30
31pub fn load_provider_config(provider: &str) -> io::Result<Option<ProviderConfig>> {
32 let path = credentials_path()?;
33 let root = read_credentials_root(&path)?;
34 let Some(providers) = root.get("providers") else {
35 return Ok(None);
36 };
37 let Some(config) = providers.get(provider) else {
38 return Ok(None);
39 };
40 serde_json::from_value::<ProviderConfig>(config.clone())
41 .map(Some)
42 .map_err(|error| io::Error::new(io::ErrorKind::InvalidData, error))
43}
44
45pub fn save_provider_config(provider: &str, config: ProviderConfig) -> io::Result<()> {
46 let path = credentials_path()?;
47 let mut root = read_credentials_root(&path)?;
48 let mut providers = root.get("providers")
49 .and_then(|v| v.as_object())
50 .cloned()
51 .unwrap_or_default();
52
53 providers.insert(
54 provider.to_string(),
55 serde_json::to_value(config).map_err(|error| io::Error::new(io::ErrorKind::InvalidData, error))?,
56 );
57
58 root.insert("providers".to_string(), Value::Object(providers));
59 write_credentials_root(&path, &root)
60}
61
62pub fn load_oauth_credentials() -> io::Result<Option<OAuthTokenSet>> {
63 let path = credentials_path()?;
64 let root = read_credentials_root(&path)?;
65 let Some(oauth) = root.get("oauth") else {
66 return Ok(None);
67 };
68 serde_json::from_value::<OAuthTokenSet>(oauth.clone())
69 .map(Some)
70 .map_err(|error| io::Error::new(io::ErrorKind::InvalidData, error))
71}
72
73pub fn save_oauth_credentials(token_set: &OAuthTokenSet) -> io::Result<()> {
74 let path = credentials_path()?;
75 let mut root = read_credentials_root(&path)?;
76 root.insert(
77 "oauth".to_string(),
78 serde_json::to_value(token_set).map_err(|error| io::Error::new(io::ErrorKind::InvalidData, error))?,
79 );
80 write_credentials_root(&path, &root)
81}
82
83pub fn clear_oauth_credentials() -> io::Result<()> {
84 let path = credentials_path()?;
85 let mut root = read_credentials_root(&path)?;
86 root.remove("oauth");
87 write_credentials_root(&path, &root)
88}
89
/// Resolves the directory that holds the credentials file.
///
/// Resolution order:
/// 1. `TERNLANG_CONFIG_HOME`, used verbatim when set to a non-empty value.
/// 2. `$HOME/.ternlang` when `HOME` is set to a non-empty value.
///
/// An environment variable that is set but empty is treated as unset.
/// Otherwise an empty value would become an empty path and the credentials
/// file would be created relative to the current working directory.
///
/// # Errors
///
/// Returns a `NotFound` error when neither variable provides a usable value.
fn credentials_home_dir() -> io::Result<PathBuf> {
    // Read an environment variable, discarding values that are empty.
    let non_empty = |name: &str| std::env::var_os(name).filter(|value| !value.is_empty());

    if let Some(path) = non_empty("TERNLANG_CONFIG_HOME") {
        return Ok(PathBuf::from(path));
    }

    let home = non_empty("HOME")
        .ok_or_else(|| io::Error::new(io::ErrorKind::NotFound, "HOME is not set"))?;
    Ok(PathBuf::from(home).join(".ternlang"))
}
98
99fn read_credentials_root(path: &PathBuf) -> io::Result<Map<String, Value>> {
100 match fs::read_to_string(path) {
101 Ok(contents) => {
102 if contents.trim().is_empty() {
103 return Ok(Map::new());
104 }
105 serde_json::from_str::<Value>(&contents)
106 .map_err(|error| io::Error::new(io::ErrorKind::InvalidData, error))?
107 .as_object()
108 .cloned()
109 .ok_or_else(|| {
110 io::Error::new(
111 io::ErrorKind::InvalidData,
112 "credentials file must contain a JSON object",
113 )
114 })
115 }
116 Err(error) if error.kind() == io::ErrorKind::NotFound => Ok(Map::new()),
117 Err(error) => Err(error),
118 }
119}
120
121fn write_credentials_root(path: &PathBuf, root: &Map<String, Value>) -> io::Result<()> {
122 if let Some(parent) = path.parent() {
123 fs::create_dir_all(parent)?;
124 }
125 let rendered = serde_json::to_string_pretty(&Value::Object(root.clone()))
126 .map_err(|error| io::Error::new(io::ErrorKind::InvalidData, error))?;
127 let temp_path = path.with_extension("json.tmp");
128 fs::write(&temp_path, format!("{rendered}\n"))?;
129 fs::rename(temp_path, path)
130}
131
132pub fn generate_pkce_pair() -> io::Result<PkceCodePair> { Ok(PkceCodePair { verifier: "".to_string(), challenge: "".to_string(), challenge_method: PkceChallengeMethod::S256 }) }
/// Produces the OAuth `state` value.
///
/// Currently always returns an empty string and never fails.
///
/// NOTE(review): this looks like a placeholder — an empty `state` cannot be
/// used to correlate or protect the callback. Confirm before relying on it.
pub fn generate_state() -> io::Result<String> {
    Ok(String::new())
}
/// Builds the loopback redirect URI `http://localhost:<port>/callback` for
/// the given `port`.
pub fn loopback_redirect_uri(port: u16) -> String {
    format!("http://localhost:{}/callback", port)
}
136pub fn parse_oauth_callback_request_target(_target: &str) -> Result<OAuthCallbackParams, String> { Ok(OAuthCallbackParams { code: None, state: None, error: None, error_description: None }) }
137
138#[derive(Debug, Clone, PartialEq, Eq)]
139pub struct PkceCodePair { pub verifier: String, pub challenge: String, pub challenge_method: PkceChallengeMethod }
/// PKCE challenge methods known to this module.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PkceChallengeMethod {
    /// The `S256` method.
    S256,
}

impl PkceChallengeMethod {
    /// Returns the method's string name (`"S256"`).
    pub const fn as_str(self) -> &'static str {
        match self {
            Self::S256 => "S256",
        }
    }
}
143#[derive(Debug, Clone, PartialEq, Eq)]
144pub struct OAuthAuthorizationRequest { pub authorize_url: String, pub client_id: String, pub redirect_uri: String, pub scopes: Vec<String>, pub state: String, pub code_challenge: String, pub code_challenge_method: PkceChallengeMethod, pub extra_params: BTreeMap<String, String> }
/// Values that make up a token-exchange request.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct OAuthTokenExchangeRequest {
    /// Grant type string.
    pub grant_type: &'static str,
    /// Authorization code.
    pub code: String,
    /// Redirect URI.
    pub redirect_uri: String,
    /// Client identifier.
    pub client_id: String,
    /// PKCE code verifier.
    pub code_verifier: String,
    /// The `state` value.
    pub state: String,
}
/// Values that make up a token-refresh request.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct OAuthRefreshRequest {
    /// Grant type string.
    pub grant_type: &'static str,
    /// Refresh token.
    pub refresh_token: String,
    /// Client identifier.
    pub client_id: String,
    /// Scopes for the request.
    pub scopes: Vec<String>,
}
/// Parameters carried by an OAuth callback; each one is optional.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct OAuthCallbackParams {
    /// The `code` parameter, if present.
    pub code: Option<String>,
    /// The `state` parameter, if present.
    pub state: Option<String>,
    /// The `error` parameter, if present.
    pub error: Option<String>,
    /// The `error_description` parameter, if present.
    pub error_description: Option<String>,
}
151impl OAuthAuthorizationRequest { pub fn from_config(_config: &OAuthConfig, _redirect_uri: String, _state: String, _pkce: &PkceCodePair) -> Self { Self { authorize_url: "".to_string(), client_id: "".to_string(), redirect_uri: "".to_string(), scopes: vec![], state: "".to_string(), code_challenge: "".to_string(), code_challenge_method: PkceChallengeMethod::S256, extra_params: BTreeMap::new() } } pub fn build_url(&self) -> String { "".to_string() } }
152impl OAuthTokenExchangeRequest { pub fn from_config(_config: &OAuthConfig, _code: String, _state: String, _verifier: String, _redirect_uri: String) -> Self { Self { grant_type: "authorization_code", code: "".to_string(), redirect_uri: "".to_string(), client_id: "".to_string(), code_verifier: "".to_string(), state: "".to_string() } } }
153
154pub fn code_challenge_s256(verifier: &str) -> String {
155 let mut hasher = Sha256::new();
156 hasher.update(verifier.as_bytes());
157 let hash = hasher.finalize();
158 hash.iter().map(|b| format!("{b:02x}")).collect()
160}
161
162pub fn parse_oauth_callback_query(_query: &str) -> Result<OAuthCallbackParams, String> {
163 Ok(OAuthCallbackParams { code: None, state: None, error: None, error_description: None })
164}