1mod error;
49mod protocol;
50
51pub use error::Error;
52pub use protocol::{Request, Response};
53pub use secrecy::{ExposeSecret, SecretString};
54
55pub use spn_core::{
57 Provider, ProviderCategory, KNOWN_PROVIDERS,
59 find_provider, provider_to_env_var, providers_by_category,
60 ValidationResult, validate_key_format, mask_key,
62 McpServer, McpServerType, McpConfig, McpSource,
64 PackageRef, PackageManifest, PackageType,
66 PullProgress, ModelInfo, RunningModel, GpuInfo, LoadConfig, BackendError,
68};
69
#[cfg(unix)]
use std::path::Path;
use std::path::PathBuf;
#[cfg(unix)]
use tokio::io::{AsyncReadExt, AsyncWriteExt};
#[cfg(unix)]
use tokio::net::UnixStream;
use tracing::debug;
#[cfg(unix)]
use tracing::warn;
78
79pub fn default_socket_path() -> PathBuf {
81 dirs::home_dir()
82 .map(|h| h.join(".spn").join("daemon.sock"))
83 .unwrap_or_else(|| PathBuf::from("/tmp/spn-daemon.sock"))
84}
85
86pub fn daemon_socket_exists() -> bool {
88 default_socket_path().exists()
89}
90
/// Client for retrieving provider secrets from the spn daemon.
///
/// On Unix the client talks to the daemon over a Unix domain socket. When the daemon
/// cannot be reached (see `connect_with_fallback`), or on non-Unix targets, the client
/// runs in *fallback mode* and answers lookups from environment variables instead.
#[derive(Debug)]
pub struct SpnClient {
    /// Open connection to the daemon; `None` in fallback mode.
    #[cfg(unix)]
    stream: Option<UnixStream>,
    /// `true` when secrets are read from environment variables rather than the daemon.
    fallback_mode: bool,
}
104
105impl SpnClient {
106 #[cfg(unix)]
112 pub async fn connect() -> Result<Self, Error> {
113 Self::connect_to(&default_socket_path()).await
114 }
115
116 #[cfg(unix)]
120 pub async fn connect_to(socket_path: &PathBuf) -> Result<Self, Error> {
121 debug!("Connecting to spn daemon at {:?}", socket_path);
122
123 let stream = UnixStream::connect(socket_path)
124 .await
125 .map_err(|e| Error::ConnectionFailed {
126 path: socket_path.clone(),
127 source: e,
128 })?;
129
130 let mut client = Self {
132 stream: Some(stream),
133 fallback_mode: false,
134 };
135
136 client.ping().await?;
137 debug!("Connected to spn daemon");
138
139 Ok(client)
140 }
141
142 #[cfg(unix)]
149 pub async fn connect_with_fallback() -> Result<Self, Error> {
150 match Self::connect().await {
151 Ok(client) => Ok(client),
152 Err(e) => {
153 warn!("spn daemon not running, using env var fallback: {}", e);
154 Ok(Self {
155 stream: None,
156 fallback_mode: true,
157 })
158 }
159 }
160 }
161
162 #[cfg(not(unix))]
167 pub async fn connect_with_fallback() -> Result<Self, Error> {
168 debug!("Non-Unix platform: using env var fallback mode");
169 Ok(Self {
170 fallback_mode: true,
171 })
172 }
173
174 pub fn is_fallback_mode(&self) -> bool {
176 self.fallback_mode
177 }
178
179 #[cfg(unix)]
183 pub async fn ping(&mut self) -> Result<String, Error> {
184 let response = self.send_request(Request::Ping).await?;
185 match response {
186 Response::Pong { version } => Ok(version),
187 Response::Error { message } => Err(Error::DaemonError(message)),
188 _ => Err(Error::UnexpectedResponse),
189 }
190 }
191
192 #[cfg(unix)]
197 pub async fn get_secret(&mut self, provider: &str) -> Result<SecretString, Error> {
198 if self.fallback_mode {
199 return self.get_secret_from_env(provider);
200 }
201
202 let response = self
203 .send_request(Request::GetSecret {
204 provider: provider.to_string(),
205 })
206 .await?;
207
208 match response {
209 Response::Secret { value } => Ok(SecretString::from(value)),
210 Response::Error { message } => Err(Error::SecretNotFound {
211 provider: provider.to_string(),
212 details: message,
213 }),
214 _ => Err(Error::UnexpectedResponse),
215 }
216 }
217
218 #[cfg(not(unix))]
222 pub async fn get_secret(&mut self, provider: &str) -> Result<SecretString, Error> {
223 self.get_secret_from_env(provider)
224 }
225
226 #[cfg(unix)]
228 pub async fn has_secret(&mut self, provider: &str) -> Result<bool, Error> {
229 if self.fallback_mode {
230 return Ok(self.get_secret_from_env(provider).is_ok());
231 }
232
233 let response = self
234 .send_request(Request::HasSecret {
235 provider: provider.to_string(),
236 })
237 .await?;
238
239 match response {
240 Response::Exists { exists } => Ok(exists),
241 Response::Error { message } => Err(Error::DaemonError(message)),
242 _ => Err(Error::UnexpectedResponse),
243 }
244 }
245
246 #[cfg(not(unix))]
250 pub async fn has_secret(&mut self, provider: &str) -> Result<bool, Error> {
251 Ok(self.get_secret_from_env(provider).is_ok())
252 }
253
254 #[cfg(unix)]
256 pub async fn list_providers(&mut self) -> Result<Vec<String>, Error> {
257 if self.fallback_mode {
258 return Ok(self.list_env_providers());
259 }
260
261 let response = self.send_request(Request::ListProviders).await?;
262
263 match response {
264 Response::Providers { providers } => Ok(providers),
265 Response::Error { message } => Err(Error::DaemonError(message)),
266 _ => Err(Error::UnexpectedResponse),
267 }
268 }
269
270 #[cfg(not(unix))]
274 pub async fn list_providers(&mut self) -> Result<Vec<String>, Error> {
275 Ok(self.list_env_providers())
276 }
277
278 #[cfg(unix)]
280 async fn send_request(&mut self, request: Request) -> Result<Response, Error> {
281 let stream = self
282 .stream
283 .as_mut()
284 .ok_or(Error::NotConnected)?;
285
286 let request_json = serde_json::to_vec(&request).map_err(Error::SerializationError)?;
288
289 let len = request_json.len() as u32;
291 stream
292 .write_all(&len.to_be_bytes())
293 .await
294 .map_err(Error::IoError)?;
295 stream
296 .write_all(&request_json)
297 .await
298 .map_err(Error::IoError)?;
299
300 let mut len_buf = [0u8; 4];
302 stream
303 .read_exact(&mut len_buf)
304 .await
305 .map_err(Error::IoError)?;
306 let response_len = u32::from_be_bytes(len_buf) as usize;
307
308 if response_len > 1_048_576 {
310 return Err(Error::ResponseTooLarge(response_len));
311 }
312
313 let mut response_buf = vec![0u8; response_len];
315 stream
316 .read_exact(&mut response_buf)
317 .await
318 .map_err(Error::IoError)?;
319
320 let response: Response =
322 serde_json::from_slice(&response_buf).map_err(Error::DeserializationError)?;
323
324 Ok(response)
325 }
326
327 fn get_secret_from_env(&self, provider: &str) -> Result<SecretString, Error> {
330 let env_var = provider_to_env_var(provider).ok_or_else(|| Error::SecretNotFound {
331 provider: provider.to_string(),
332 details: format!("Unknown provider: {provider}"),
333 })?;
334 std::env::var(env_var)
335 .map(SecretString::from)
336 .map_err(|_| Error::SecretNotFound {
337 provider: provider.to_string(),
338 details: format!("Environment variable {env_var} not set"),
339 })
340 }
341
342 fn list_env_providers(&self) -> Vec<String> {
343 KNOWN_PROVIDERS
344 .iter()
345 .filter(|p| std::env::var(p.env_var).is_ok())
346 .map(|p| p.id.to_string())
347 .collect()
348 }
349}
350
351
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a client in fallback mode without touching the network or a runtime.
    fn fallback_client() -> SpnClient {
        SpnClient {
            #[cfg(unix)]
            stream: None,
            fallback_mode: true,
        }
    }

    #[test]
    fn test_provider_to_env_var() {
        assert_eq!(provider_to_env_var("anthropic"), Some("ANTHROPIC_API_KEY"));
        assert_eq!(provider_to_env_var("openai"), Some("OPENAI_API_KEY"));
        assert_eq!(provider_to_env_var("neo4j"), Some("NEO4J_PASSWORD"));
        assert_eq!(provider_to_env_var("github"), Some("GITHUB_TOKEN"));
        assert_eq!(provider_to_env_var("unknown"), None);
    }

    #[test]
    fn test_default_socket_path() {
        let path = default_socket_path();
        // Pin both branches exactly, so the test also holds where no home directory
        // can be resolved (the `/tmp` fallback contains no `.spn` component).
        match dirs::home_dir() {
            Some(home) => assert_eq!(path, home.join(".spn").join("daemon.sock")),
            None => assert_eq!(path, PathBuf::from("/tmp/spn-daemon.sock")),
        }
    }

    #[test]
    fn test_daemon_socket_exists() {
        // Whether a daemon is running depends on the machine, so check consistency
        // with the filesystem instead of hard-coding `false`.
        assert_eq!(daemon_socket_exists(), default_socket_path().exists());
    }

    #[test]
    fn test_fallback_unknown_provider_is_not_found() {
        let client = fallback_client();
        assert!(client.is_fallback_mode());
        assert!(matches!(
            client.get_secret_from_env("unknown"),
            Err(Error::SecretNotFound { .. })
        ));
    }
}