1mod error;
49mod paths;
50mod protocol;
51
52pub use error::Error;
53pub use paths::{PathError, SpnPaths};
54pub use protocol::{
55 ForeignMcpInfo, IpcJobState, IpcJobStatus, IpcSchedulerStats, ModelProgress, RecentProjectInfo,
56 Request, Response, WatcherStatusInfo, PROTOCOL_VERSION,
57};
58pub use secrecy::{ExposeSecret, SecretString};
59
60pub use spn_core::{
62 find_provider,
63 mask_key,
64 provider_to_env_var,
65 providers_by_category,
66 validate_key_format,
67 BackendError,
68 ChatMessage,
70 ChatOptions,
71 ChatResponse,
72 ChatRole,
73 GpuInfo,
74 LoadConfig,
75 McpConfig,
76 McpServer,
78 McpServerType,
79 McpSource,
80 ModelInfo,
81 PackageManifest,
82 PackageRef,
84 PackageType,
85 Provider,
87 ProviderCategory,
88 PullProgress,
90 RunningModel,
91 Source,
92 ValidationResult,
94 KNOWN_PROVIDERS,
95};
96
97use std::path::PathBuf;
98use std::time::Duration;
99#[cfg(unix)]
100use tokio::io::{AsyncReadExt, AsyncWriteExt};
101#[cfg(unix)]
102use tokio::net::UnixStream;
103use tracing::debug;
104#[cfg(unix)]
105use tracing::warn;
106
/// Default deadline for a single request/response round trip with the daemon (30 s).
///
/// Individual clients can override it with [`SpnClient::set_timeout`].
pub const DEFAULT_IPC_TIMEOUT: Duration = Duration::from_millis(30_000);
109
110pub fn socket_path() -> Result<PathBuf, Error> {
117 SpnPaths::new().map(|p| p.socket_file()).map_err(|_| {
118 Error::Configuration("HOME directory not found. Set HOME environment variable.".into())
119 })
120}
121
122pub fn daemon_socket_exists() -> bool {
126 socket_path().map(|p| p.exists()).unwrap_or(false)
127}
128
/// Client for the spn daemon's local IPC socket.
///
/// A client either holds a live daemon connection or runs in "fallback mode".
/// In fallback mode, secret lookups are answered from environment variables
/// instead of the daemon.
#[derive(Debug)]
pub struct SpnClient {
    /// Connection to the daemon socket (Unix only); `None` when no daemon
    /// connection is held, e.g. in fallback mode.
    #[cfg(unix)]
    stream: Option<UnixStream>,
    /// When `true`, secret-related calls read environment variables instead of
    /// talking to the daemon.
    fallback_mode: bool,
    /// Deadline applied to each request/response round trip.
    timeout: Duration,
}
144
145impl SpnClient {
146 #[cfg(unix)]
152 pub async fn connect() -> Result<Self, Error> {
153 let path = socket_path()?;
154 Self::connect_to(&path).await
155 }
156
157 #[cfg(unix)]
161 pub async fn connect_to(socket_path: &PathBuf) -> Result<Self, Error> {
162 debug!("Connecting to spn daemon at {:?}", socket_path);
163
164 let stream =
165 UnixStream::connect(socket_path)
166 .await
167 .map_err(|e| Error::ConnectionFailed {
168 path: socket_path.clone(),
169 source: e,
170 })?;
171
172 let mut client = Self {
174 stream: Some(stream),
175 fallback_mode: false,
176 timeout: DEFAULT_IPC_TIMEOUT,
177 };
178
179 client.ping().await?;
180 debug!("Connected to spn daemon");
181
182 Ok(client)
183 }
184
185 pub fn set_timeout(&mut self, timeout: Duration) {
189 self.timeout = timeout;
190 }
191
192 pub fn timeout(&self) -> Duration {
194 self.timeout
195 }
196
197 #[cfg(unix)]
204 pub async fn connect_with_fallback() -> Result<Self, Error> {
205 match Self::connect().await {
206 Ok(client) => Ok(client),
207 Err(e) => {
208 warn!("spn daemon not running, using env var fallback: {}", e);
209 Ok(Self {
210 stream: None,
211 fallback_mode: true,
212 timeout: DEFAULT_IPC_TIMEOUT,
213 })
214 }
215 }
216 }
217
218 #[cfg(not(unix))]
223 pub async fn connect_with_fallback() -> Result<Self, Error> {
224 debug!("Non-Unix platform: using env var fallback mode");
225 Ok(Self {
226 fallback_mode: true,
227 timeout: DEFAULT_IPC_TIMEOUT,
228 })
229 }
230
231 pub fn is_fallback_mode(&self) -> bool {
233 self.fallback_mode
234 }
235
236 #[cfg(unix)]
243 pub async fn ping(&mut self) -> Result<String, Error> {
244 let response = self.send_request(Request::Ping).await?;
245 match response {
246 Response::Pong {
247 protocol_version,
248 version,
249 } => {
250 if protocol_version != protocol::PROTOCOL_VERSION {
252 warn!(
253 "Protocol version mismatch: client v{}, daemon v{}. \
254 Consider updating your daemon with 'spn daemon restart'.",
255 protocol::PROTOCOL_VERSION,
256 protocol_version
257 );
258 }
259 Ok(version)
260 }
261 Response::Error { message } => Err(Error::DaemonError(message)),
262 _ => Err(Error::UnexpectedResponse),
263 }
264 }
265
266 #[cfg(unix)]
271 pub async fn get_secret(&mut self, provider: &str) -> Result<SecretString, Error> {
272 if self.fallback_mode {
273 return self.get_secret_from_env(provider);
274 }
275
276 let response = self
277 .send_request(Request::GetSecret {
278 provider: provider.to_string(),
279 })
280 .await?;
281
282 match response {
283 Response::Secret { value } => Ok(SecretString::from(value)),
284 Response::Error { message } => Err(Error::SecretNotFound {
285 provider: provider.to_string(),
286 details: message,
287 }),
288 _ => Err(Error::UnexpectedResponse),
289 }
290 }
291
292 #[cfg(not(unix))]
296 pub async fn get_secret(&mut self, provider: &str) -> Result<SecretString, Error> {
297 self.get_secret_from_env(provider)
298 }
299
300 #[cfg(unix)]
302 pub async fn has_secret(&mut self, provider: &str) -> Result<bool, Error> {
303 if self.fallback_mode {
304 return Ok(self.get_secret_from_env(provider).is_ok());
305 }
306
307 let response = self
308 .send_request(Request::HasSecret {
309 provider: provider.to_string(),
310 })
311 .await?;
312
313 match response {
314 Response::Exists { exists } => Ok(exists),
315 Response::Error { message } => Err(Error::DaemonError(message)),
316 _ => Err(Error::UnexpectedResponse),
317 }
318 }
319
320 #[cfg(not(unix))]
324 pub async fn has_secret(&mut self, provider: &str) -> Result<bool, Error> {
325 Ok(self.get_secret_from_env(provider).is_ok())
326 }
327
328 #[cfg(unix)]
330 pub async fn list_providers(&mut self) -> Result<Vec<String>, Error> {
331 if self.fallback_mode {
332 return Ok(self.list_env_providers());
333 }
334
335 let response = self.send_request(Request::ListProviders).await?;
336
337 match response {
338 Response::Providers { providers } => Ok(providers),
339 Response::Error { message } => Err(Error::DaemonError(message)),
340 _ => Err(Error::UnexpectedResponse),
341 }
342 }
343
344 #[cfg(not(unix))]
348 pub async fn list_providers(&mut self) -> Result<Vec<String>, Error> {
349 Ok(self.list_env_providers())
350 }
351
352 #[cfg(unix)]
357 pub async fn refresh_secret(&mut self, provider: &str) -> Result<bool, Error> {
358 if self.fallback_mode {
359 return Ok(true);
361 }
362
363 let response = self
364 .send_request(Request::RefreshSecret {
365 provider: provider.to_string(),
366 })
367 .await?;
368
369 match response {
370 Response::Refreshed { refreshed, .. } => Ok(refreshed),
371 Response::Error { message } => Err(Error::DaemonError(message)),
372 _ => Err(Error::UnexpectedResponse),
373 }
374 }
375
376 #[cfg(not(unix))]
380 pub async fn refresh_secret(&mut self, _provider: &str) -> Result<bool, Error> {
381 Ok(true) }
383
384 #[cfg(unix)]
388 pub async fn watcher_status(&mut self) -> Result<WatcherStatusInfo, Error> {
389 if self.fallback_mode {
390 return Err(Error::DaemonError(
391 "Watcher status not available in fallback mode".into(),
392 ));
393 }
394
395 let response = self.send_request(Request::WatcherStatus).await?;
396
397 match response {
398 Response::WatcherStatusResult { status } => Ok(status),
399 Response::Error { message } => Err(Error::DaemonError(message)),
400 _ => Err(Error::UnexpectedResponse),
401 }
402 }
403
404 #[cfg(not(unix))]
408 pub async fn watcher_status(&mut self) -> Result<WatcherStatusInfo, Error> {
409 Err(Error::DaemonError(
410 "Watcher status not available on non-Unix platforms".into(),
411 ))
412 }
413
414 #[cfg(unix)]
421 pub async fn send_request(&mut self, request: Request) -> Result<Response, Error> {
422 let timeout_duration = self.timeout;
423 let timeout_secs = timeout_duration.as_secs();
424
425 tokio::time::timeout(timeout_duration, self.send_request_inner(request))
427 .await
428 .map_err(|_| Error::Timeout(timeout_secs))?
429 }
430
431 #[cfg(unix)]
433 async fn send_request_inner(&mut self, request: Request) -> Result<Response, Error> {
434 let stream = self.stream.as_mut().ok_or(Error::NotConnected)?;
435
436 let request_json = serde_json::to_vec(&request).map_err(Error::SerializationError)?;
438
439 let len = request_json.len() as u32;
441 stream
442 .write_all(&len.to_be_bytes())
443 .await
444 .map_err(Error::IoError)?;
445 stream
446 .write_all(&request_json)
447 .await
448 .map_err(Error::IoError)?;
449
450 let mut len_buf = [0u8; 4];
452 stream
453 .read_exact(&mut len_buf)
454 .await
455 .map_err(Error::IoError)?;
456 let response_len = u32::from_be_bytes(len_buf) as usize;
457
458 if response_len > 1_048_576 {
460 return Err(Error::ResponseTooLarge(response_len));
461 }
462
463 let mut response_buf = vec![0u8; response_len];
465 stream
466 .read_exact(&mut response_buf)
467 .await
468 .map_err(Error::IoError)?;
469
470 let response: Response =
472 serde_json::from_slice(&response_buf).map_err(Error::DeserializationError)?;
473
474 Ok(response)
475 }
476
477 fn get_secret_from_env(&self, provider: &str) -> Result<SecretString, Error> {
480 let env_var = provider_to_env_var(provider).ok_or_else(|| Error::SecretNotFound {
481 provider: provider.to_string(),
482 details: format!("Unknown provider: {provider}"),
483 })?;
484 std::env::var(env_var)
485 .map(SecretString::from)
486 .map_err(|_| Error::SecretNotFound {
487 provider: provider.to_string(),
488 details: format!("Environment variable {env_var} not set"),
489 })
490 }
491
492 fn list_env_providers(&self) -> Vec<String> {
493 KNOWN_PROVIDERS
494 .iter()
495 .filter(|p| std::env::var(p.env_var).is_ok())
496 .map(|p| p.id.to_string())
497 .collect()
498 }
499}
500
#[cfg(test)]
mod tests {
    use super::*;

    /// Known providers map to their conventional env var; unknown ones map to `None`.
    #[test]
    fn test_provider_to_env_var() {
        let cases = [
            ("anthropic", Some("ANTHROPIC_API_KEY")),
            ("openai", Some("OPENAI_API_KEY")),
            ("neo4j", Some("NEO4J_PASSWORD")),
            ("github", Some("GITHUB_TOKEN")),
            ("unknown", None),
        ];
        for (provider, expected) in cases {
            assert_eq!(provider_to_env_var(provider), expected, "provider {provider}");
        }
    }

    /// When the socket path resolves, it lives under `.spn` and is named `daemon.sock`.
    #[test]
    fn test_socket_path() {
        if let Ok(path) = socket_path() {
            let rendered = path.to_string_lossy();
            for fragment in [".spn", "daemon.sock"] {
                assert!(rendered.contains(fragment), "{rendered} lacks {fragment}");
            }
        }
    }

    /// Smoke test: the existence check must not panic whether or not a daemon runs.
    #[test]
    fn test_daemon_socket_exists() {
        let _ = daemon_socket_exists();
    }
}