Skip to main content

spn_client/
lib.rs

//! # spn-client
//!
//! Client library for communicating with the spn daemon.
//!
//! This crate provides a simple interface for applications (like Nika) to securely
//! retrieve secrets from the spn daemon without directly accessing the OS keychain.
//!
//! ## Usage
//!
//! ```rust,no_run
//! use spn_client::{SpnClient, ExposeSecret};
//!
//! # async fn example() -> Result<(), spn_client::Error> {
//! // Connect to the daemon (Unix only; see `connect_with_fallback` for other platforms)
//! let mut client = SpnClient::connect().await?;
//!
//! // Get a secret. Call `expose_secret()` only where the raw value is needed,
//! // and never print or log it.
//! let api_key = client.get_secret("anthropic").await?;
//! println!("Got key ({} chars)", api_key.expose_secret().len());
//!
//! // Check if a secret exists
//! if client.has_secret("openai").await? {
//!     println!("OpenAI key available");
//! }
//!
//! // List all providers
//! let providers = client.list_providers().await?;
//! println!("Available providers: {:?}", providers);
//! # Ok(())
//! # }
//! ```
//!
//! ## Fallback Mode
//!
//! If the daemon is not running, the client can fall back to reading secrets from
//! environment variables (for example `ANTHROPIC_API_KEY` for the `anthropic` provider):
//!
//! ```rust,no_run
//! use spn_client::SpnClient;
//!
//! # async fn example() -> Result<(), spn_client::Error> {
//! let mut client = SpnClient::connect_with_fallback().await?;
//! // Works even if daemon is not running
//! # Ok(())
//! # }
//! ```
47
48mod error;
49mod paths;
50mod protocol;
51
52pub use error::Error;
53pub use paths::{PathError, SpnPaths};
54pub use protocol::{Request, Response};
55pub use secrecy::{ExposeSecret, SecretString};
56
57// Re-export all spn-core types for convenience
58pub use spn_core::{
59    find_provider,
60    mask_key,
61    provider_to_env_var,
62    providers_by_category,
63    validate_key_format,
64    BackendError,
65    GpuInfo,
66    LoadConfig,
67    McpConfig,
68    // MCP
69    McpServer,
70    McpServerType,
71    McpSource,
72    ModelInfo,
73    PackageManifest,
74    // Registry
75    PackageRef,
76    PackageType,
77    // Providers
78    Provider,
79    ProviderCategory,
80    // Backend
81    PullProgress,
82    RunningModel,
83    Source,
84    // Validation
85    ValidationResult,
86    KNOWN_PROVIDERS,
87};
88
89use std::path::PathBuf;
90use std::time::Duration;
91#[cfg(unix)]
92use tokio::io::{AsyncReadExt, AsyncWriteExt};
93#[cfg(unix)]
94use tokio::net::UnixStream;
95use tracing::debug;
96#[cfg(unix)]
97use tracing::warn;
98
/// Default per-request timeout for daemon IPC: 30 seconds.
///
/// Every client starts with this value; it can be changed per client.
pub const DEFAULT_IPC_TIMEOUT: Duration = Duration::from_millis(30_000);
101
102/// Get socket path for the spn daemon, returning an error if HOME is unavailable.
103///
104/// Use this function when you need to ensure a secure socket path.
105/// Returns an error instead of falling back to `/tmp`.
106///
107/// This is a convenience wrapper around `SpnPaths::new()?.socket_file()`.
108pub fn socket_path() -> Result<PathBuf, Error> {
109    SpnPaths::new().map(|p| p.socket_file()).map_err(|_| {
110        Error::Configuration("HOME directory not found. Set HOME environment variable.".into())
111    })
112}
113
114/// Check if the daemon socket exists.
115///
116/// Returns `false` if HOME directory is unavailable.
117pub fn daemon_socket_exists() -> bool {
118    socket_path().map(|p| p.exists()).unwrap_or(false)
119}
120
/// Client for communicating with the spn daemon.
///
/// The client uses Unix socket IPC to communicate with the daemon,
/// which handles all keychain access to avoid repeated auth prompts.
///
/// On non-Unix platforms (Windows), the client always operates in fallback mode,
/// reading secrets from environment variables.
#[derive(Debug)]
pub struct SpnClient {
    /// Connection to the daemon. `None` for clients created in fallback mode;
    /// requests sent without a stream fail with `Error::NotConnected`.
    #[cfg(unix)]
    stream: Option<UnixStream>,
    /// When `true`, secret lookups read environment variables instead of
    /// querying the daemon.
    fallback_mode: bool,
    /// Timeout for IPC operations.
    timeout: Duration,
}
136
137impl SpnClient {
138    /// Connect to the spn daemon.
139    ///
140    /// Returns an error if the daemon is not running.
141    ///
142    /// This method is only available on Unix platforms.
143    #[cfg(unix)]
144    pub async fn connect() -> Result<Self, Error> {
145        let path = socket_path()?;
146        Self::connect_to(&path).await
147    }
148
149    /// Connect to the daemon at a specific socket path.
150    ///
151    /// This method is only available on Unix platforms.
152    #[cfg(unix)]
153    pub async fn connect_to(socket_path: &PathBuf) -> Result<Self, Error> {
154        debug!("Connecting to spn daemon at {:?}", socket_path);
155
156        let stream =
157            UnixStream::connect(socket_path)
158                .await
159                .map_err(|e| Error::ConnectionFailed {
160                    path: socket_path.clone(),
161                    source: e,
162                })?;
163
164        // Verify connection with ping
165        let mut client = Self {
166            stream: Some(stream),
167            fallback_mode: false,
168            timeout: DEFAULT_IPC_TIMEOUT,
169        };
170
171        client.ping().await?;
172        debug!("Connected to spn daemon");
173
174        Ok(client)
175    }
176
177    /// Set the timeout for IPC operations.
178    ///
179    /// The default timeout is 30 seconds.
180    pub fn set_timeout(&mut self, timeout: Duration) {
181        self.timeout = timeout;
182    }
183
184    /// Get the current timeout for IPC operations.
185    pub fn timeout(&self) -> Duration {
186        self.timeout
187    }
188
189    /// Connect to the daemon, falling back to env vars if daemon is unavailable.
190    ///
191    /// This is the recommended way to connect in applications that should
192    /// work even without the daemon running.
193    ///
194    /// On non-Unix platforms (Windows), this always returns a fallback client.
195    #[cfg(unix)]
196    pub async fn connect_with_fallback() -> Result<Self, Error> {
197        match Self::connect().await {
198            Ok(client) => Ok(client),
199            Err(e) => {
200                warn!("spn daemon not running, using env var fallback: {}", e);
201                Ok(Self {
202                    stream: None,
203                    fallback_mode: true,
204                    timeout: DEFAULT_IPC_TIMEOUT,
205                })
206            }
207        }
208    }
209
210    /// Connect to the daemon, falling back to env vars if daemon is unavailable.
211    ///
212    /// On non-Unix platforms (Windows), this always returns a fallback client
213    /// since Unix sockets are not available.
214    #[cfg(not(unix))]
215    pub async fn connect_with_fallback() -> Result<Self, Error> {
216        debug!("Non-Unix platform: using env var fallback mode");
217        Ok(Self {
218            fallback_mode: true,
219            timeout: DEFAULT_IPC_TIMEOUT,
220        })
221    }
222
223    /// Check if the client is in fallback mode (daemon not connected).
224    pub fn is_fallback_mode(&self) -> bool {
225        self.fallback_mode
226    }
227
228    /// Ping the daemon to verify the connection.
229    ///
230    /// This method is only available on Unix platforms.
231    #[cfg(unix)]
232    pub async fn ping(&mut self) -> Result<String, Error> {
233        let response = self.send_request(Request::Ping).await?;
234        match response {
235            Response::Pong { version } => Ok(version),
236            Response::Error { message } => Err(Error::DaemonError(message)),
237            _ => Err(Error::UnexpectedResponse),
238        }
239    }
240
241    /// Get a secret for the given provider.
242    ///
243    /// In fallback mode, attempts to read from the environment variable
244    /// associated with the provider (e.g., ANTHROPIC_API_KEY).
245    #[cfg(unix)]
246    pub async fn get_secret(&mut self, provider: &str) -> Result<SecretString, Error> {
247        if self.fallback_mode {
248            return self.get_secret_from_env(provider);
249        }
250
251        let response = self
252            .send_request(Request::GetSecret {
253                provider: provider.to_string(),
254            })
255            .await?;
256
257        match response {
258            Response::Secret { value } => Ok(SecretString::from(value)),
259            Response::Error { message } => Err(Error::SecretNotFound {
260                provider: provider.to_string(),
261                details: message,
262            }),
263            _ => Err(Error::UnexpectedResponse),
264        }
265    }
266
267    /// Get a secret for the given provider.
268    ///
269    /// On non-Unix platforms, always reads from environment variables.
270    #[cfg(not(unix))]
271    pub async fn get_secret(&mut self, provider: &str) -> Result<SecretString, Error> {
272        self.get_secret_from_env(provider)
273    }
274
275    /// Check if a secret exists for the given provider.
276    #[cfg(unix)]
277    pub async fn has_secret(&mut self, provider: &str) -> Result<bool, Error> {
278        if self.fallback_mode {
279            return Ok(self.get_secret_from_env(provider).is_ok());
280        }
281
282        let response = self
283            .send_request(Request::HasSecret {
284                provider: provider.to_string(),
285            })
286            .await?;
287
288        match response {
289            Response::Exists { exists } => Ok(exists),
290            Response::Error { message } => Err(Error::DaemonError(message)),
291            _ => Err(Error::UnexpectedResponse),
292        }
293    }
294
295    /// Check if a secret exists for the given provider.
296    ///
297    /// On non-Unix platforms, checks environment variables.
298    #[cfg(not(unix))]
299    pub async fn has_secret(&mut self, provider: &str) -> Result<bool, Error> {
300        Ok(self.get_secret_from_env(provider).is_ok())
301    }
302
303    /// List all available providers.
304    #[cfg(unix)]
305    pub async fn list_providers(&mut self) -> Result<Vec<String>, Error> {
306        if self.fallback_mode {
307            return Ok(self.list_env_providers());
308        }
309
310        let response = self.send_request(Request::ListProviders).await?;
311
312        match response {
313            Response::Providers { providers } => Ok(providers),
314            Response::Error { message } => Err(Error::DaemonError(message)),
315            _ => Err(Error::UnexpectedResponse),
316        }
317    }
318
319    /// List all available providers.
320    ///
321    /// On non-Unix platforms, lists providers from environment variables.
322    #[cfg(not(unix))]
323    pub async fn list_providers(&mut self) -> Result<Vec<String>, Error> {
324        Ok(self.list_env_providers())
325    }
326
327    /// Send a request to the daemon and receive a response.
328    ///
329    /// This is a low-level method for sending arbitrary requests.
330    /// For common operations, use the convenience methods like `get_secret()`.
331    ///
332    /// The request will time out after the configured timeout (default 30 seconds).
333    #[cfg(unix)]
334    pub async fn send_request(&mut self, request: Request) -> Result<Response, Error> {
335        let timeout_duration = self.timeout;
336        let timeout_secs = timeout_duration.as_secs();
337
338        // Wrap the entire operation in a timeout
339        tokio::time::timeout(timeout_duration, self.send_request_inner(request))
340            .await
341            .map_err(|_| Error::Timeout(timeout_secs))?
342    }
343
344    /// Inner implementation of send_request without timeout.
345    #[cfg(unix)]
346    async fn send_request_inner(&mut self, request: Request) -> Result<Response, Error> {
347        let stream = self.stream.as_mut().ok_or(Error::NotConnected)?;
348
349        // Serialize request
350        let request_json = serde_json::to_vec(&request).map_err(Error::SerializationError)?;
351
352        // Send length-prefixed message
353        let len = request_json.len() as u32;
354        stream
355            .write_all(&len.to_be_bytes())
356            .await
357            .map_err(Error::IoError)?;
358        stream
359            .write_all(&request_json)
360            .await
361            .map_err(Error::IoError)?;
362
363        // Read response length
364        let mut len_buf = [0u8; 4];
365        stream
366            .read_exact(&mut len_buf)
367            .await
368            .map_err(Error::IoError)?;
369        let response_len = u32::from_be_bytes(len_buf) as usize;
370
371        // Sanity check response length (max 1MB)
372        if response_len > 1_048_576 {
373            return Err(Error::ResponseTooLarge(response_len));
374        }
375
376        // Read response
377        let mut response_buf = vec![0u8; response_len];
378        stream
379            .read_exact(&mut response_buf)
380            .await
381            .map_err(Error::IoError)?;
382
383        // Deserialize
384        let response: Response =
385            serde_json::from_slice(&response_buf).map_err(Error::DeserializationError)?;
386
387        Ok(response)
388    }
389
390    // Fallback helpers
391
392    fn get_secret_from_env(&self, provider: &str) -> Result<SecretString, Error> {
393        let env_var = provider_to_env_var(provider).ok_or_else(|| Error::SecretNotFound {
394            provider: provider.to_string(),
395            details: format!("Unknown provider: {provider}"),
396        })?;
397        std::env::var(env_var)
398            .map(SecretString::from)
399            .map_err(|_| Error::SecretNotFound {
400                provider: provider.to_string(),
401                details: format!("Environment variable {env_var} not set"),
402            })
403    }
404
405    fn list_env_providers(&self) -> Vec<String> {
406        KNOWN_PROVIDERS
407            .iter()
408            .filter(|p| std::env::var(p.env_var).is_ok())
409            .map(|p| p.id.to_string())
410            .collect()
411    }
412}
413
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_provider_to_env_var() {
        // `provider_to_env_var` is re-exported from spn_core and yields `None`
        // for providers it does not recognise.
        let cases = [
            ("anthropic", Some("ANTHROPIC_API_KEY")),
            ("openai", Some("OPENAI_API_KEY")),
            ("neo4j", Some("NEO4J_PASSWORD")),
            ("github", Some("GITHUB_TOKEN")),
            ("unknown", None),
        ];
        for (provider, expected) in cases {
            assert_eq!(provider_to_env_var(provider), expected);
        }
    }

    #[test]
    fn test_socket_path() {
        // Only checkable when HOME resolves; otherwise there is nothing to assert.
        if let Ok(resolved) = socket_path() {
            let rendered = resolved.to_string_lossy();
            assert!(rendered.contains(".spn"));
            assert!(rendered.contains("daemon.sock"));
        }
    }

    #[test]
    fn test_daemon_socket_exists() {
        // Smoke test only: the answer depends on whether a daemon socket is present.
        let _ = daemon_socket_exists();
    }
}