1mod error;
49mod protocol;
50
51pub use error::Error;
52pub use protocol::{Request, Response};
53pub use secrecy::{ExposeSecret, SecretString};
54
55pub use spn_core::{
57 Provider, ProviderCategory, KNOWN_PROVIDERS,
59 find_provider, provider_to_env_var, providers_by_category,
60 ValidationResult, validate_key_format, mask_key,
62 McpServer, McpServerType, McpConfig, McpSource,
64 PackageRef, PackageManifest, PackageType,
66 PullProgress, ModelInfo, RunningModel, GpuInfo, LoadConfig, BackendError,
68};
69
70use std::path::PathBuf;
71use tokio::io::{AsyncReadExt, AsyncWriteExt};
72use tokio::net::UnixStream;
73use tracing::{debug, warn};
74
75pub fn default_socket_path() -> PathBuf {
77 dirs::home_dir()
78 .map(|h| h.join(".spn").join("daemon.sock"))
79 .unwrap_or_else(|| PathBuf::from("/tmp/spn-daemon.sock"))
80}
81
82pub fn daemon_socket_exists() -> bool {
84 default_socket_path().exists()
85}
86
/// Client for the spn daemon's Unix-socket protocol.
///
/// A client is either connected to the daemon (`stream` is `Some`) or in
/// fallback mode, in which lookups are answered from environment variables.
#[derive(Debug)]
pub struct SpnClient {
    /// Open connection to the daemon; `None` when the client was built in fallback mode.
    stream: Option<UnixStream>,
    /// `true` when the daemon could not be reached and environment variables are used instead.
    fallback_mode: bool,
}
96
97impl SpnClient {
98 pub async fn connect() -> Result<Self, Error> {
102 Self::connect_to(&default_socket_path()).await
103 }
104
105 pub async fn connect_to(socket_path: &PathBuf) -> Result<Self, Error> {
107 debug!("Connecting to spn daemon at {:?}", socket_path);
108
109 let stream = UnixStream::connect(socket_path)
110 .await
111 .map_err(|e| Error::ConnectionFailed {
112 path: socket_path.clone(),
113 source: e,
114 })?;
115
116 let mut client = Self {
118 stream: Some(stream),
119 fallback_mode: false,
120 };
121
122 client.ping().await?;
123 debug!("Connected to spn daemon");
124
125 Ok(client)
126 }
127
128 pub async fn connect_with_fallback() -> Result<Self, Error> {
133 match Self::connect().await {
134 Ok(client) => Ok(client),
135 Err(e) => {
136 warn!("spn daemon not running, using env var fallback: {}", e);
137 Ok(Self {
138 stream: None,
139 fallback_mode: true,
140 })
141 }
142 }
143 }
144
145 pub fn is_fallback_mode(&self) -> bool {
147 self.fallback_mode
148 }
149
150 pub async fn ping(&mut self) -> Result<String, Error> {
152 let response = self.send_request(Request::Ping).await?;
153 match response {
154 Response::Pong { version } => Ok(version),
155 Response::Error { message } => Err(Error::DaemonError(message)),
156 _ => Err(Error::UnexpectedResponse),
157 }
158 }
159
160 pub async fn get_secret(&mut self, provider: &str) -> Result<SecretString, Error> {
165 if self.fallback_mode {
166 return self.get_secret_from_env(provider);
167 }
168
169 let response = self
170 .send_request(Request::GetSecret {
171 provider: provider.to_string(),
172 })
173 .await?;
174
175 match response {
176 Response::Secret { value } => Ok(SecretString::from(value)),
177 Response::Error { message } => Err(Error::SecretNotFound {
178 provider: provider.to_string(),
179 details: message,
180 }),
181 _ => Err(Error::UnexpectedResponse),
182 }
183 }
184
185 pub async fn has_secret(&mut self, provider: &str) -> Result<bool, Error> {
187 if self.fallback_mode {
188 return Ok(self.get_secret_from_env(provider).is_ok());
189 }
190
191 let response = self
192 .send_request(Request::HasSecret {
193 provider: provider.to_string(),
194 })
195 .await?;
196
197 match response {
198 Response::Exists { exists } => Ok(exists),
199 Response::Error { message } => Err(Error::DaemonError(message)),
200 _ => Err(Error::UnexpectedResponse),
201 }
202 }
203
204 pub async fn list_providers(&mut self) -> Result<Vec<String>, Error> {
206 if self.fallback_mode {
207 return Ok(self.list_env_providers());
208 }
209
210 let response = self.send_request(Request::ListProviders).await?;
211
212 match response {
213 Response::Providers { providers } => Ok(providers),
214 Response::Error { message } => Err(Error::DaemonError(message)),
215 _ => Err(Error::UnexpectedResponse),
216 }
217 }
218
219 async fn send_request(&mut self, request: Request) -> Result<Response, Error> {
221 let stream = self
222 .stream
223 .as_mut()
224 .ok_or(Error::NotConnected)?;
225
226 let request_json = serde_json::to_vec(&request).map_err(Error::SerializationError)?;
228
229 let len = request_json.len() as u32;
231 stream
232 .write_all(&len.to_be_bytes())
233 .await
234 .map_err(Error::IoError)?;
235 stream
236 .write_all(&request_json)
237 .await
238 .map_err(Error::IoError)?;
239
240 let mut len_buf = [0u8; 4];
242 stream
243 .read_exact(&mut len_buf)
244 .await
245 .map_err(Error::IoError)?;
246 let response_len = u32::from_be_bytes(len_buf) as usize;
247
248 if response_len > 1_048_576 {
250 return Err(Error::ResponseTooLarge(response_len));
251 }
252
253 let mut response_buf = vec![0u8; response_len];
255 stream
256 .read_exact(&mut response_buf)
257 .await
258 .map_err(Error::IoError)?;
259
260 let response: Response =
262 serde_json::from_slice(&response_buf).map_err(Error::DeserializationError)?;
263
264 Ok(response)
265 }
266
267 fn get_secret_from_env(&self, provider: &str) -> Result<SecretString, Error> {
270 let env_var = provider_to_env_var(provider).ok_or_else(|| Error::SecretNotFound {
271 provider: provider.to_string(),
272 details: format!("Unknown provider: {provider}"),
273 })?;
274 std::env::var(env_var)
275 .map(SecretString::from)
276 .map_err(|_| Error::SecretNotFound {
277 provider: provider.to_string(),
278 details: format!("Environment variable {env_var} not set"),
279 })
280 }
281
282 fn list_env_providers(&self) -> Vec<String> {
283 KNOWN_PROVIDERS
284 .iter()
285 .filter(|p| std::env::var(p.env_var).is_ok())
286 .map(|p| p.id.to_string())
287 .collect()
288 }
289}
290
291
#[cfg(test)]
mod tests {
    use super::*;

    /// Known provider ids map to their environment variable; unknown ids map to `None`.
    #[test]
    fn test_provider_to_env_var() {
        assert_eq!(provider_to_env_var("anthropic"), Some("ANTHROPIC_API_KEY"));
        assert_eq!(provider_to_env_var("openai"), Some("OPENAI_API_KEY"));
        assert_eq!(provider_to_env_var("neo4j"), Some("NEO4J_PASSWORD"));
        assert_eq!(provider_to_env_var("github"), Some("GITHUB_TOKEN"));
        assert_eq!(provider_to_env_var("unknown"), None);
    }

    /// The default path ends in `daemon.sock` and lives under `.spn` when a
    /// home directory is available, or is the `/tmp` fallback otherwise.
    #[test]
    fn test_default_socket_path() {
        let path = default_socket_path();
        let rendered = path.to_string_lossy();
        assert!(rendered.ends_with("daemon.sock"));

        // Only expect the `.spn` directory when a home directory can be
        // resolved; otherwise the function takes its fallback branch.
        if dirs::home_dir().is_some() {
            assert!(rendered.contains(".spn"));
        } else {
            assert_eq!(path, PathBuf::from("/tmp/spn-daemon.sock"));
        }
    }

    /// `daemon_socket_exists` must agree with the filesystem.
    #[test]
    fn test_daemon_socket_exists() {
        // Whether a daemon is running is a property of the host, so compare
        // against the filesystem instead of asserting a fixed answer.
        assert_eq!(daemon_socket_exists(), default_socket_path().exists());
    }
}