Skip to main content

neuron_turn/
provider.rs

//! Provider trait for LLM backends.
//!
//! The [`Provider`] trait uses RPITIT (return-position `impl Trait` in traits)
//! and is intentionally NOT object-safe. The object-safe boundary is
//! `layer0::Turn`: `NeuronTurn<P: Provider>` implements `Turn`.
6
7use crate::types::{ProviderRequest, ProviderResponse};
8use std::future::Future;
9use thiserror::Error;
10
11/// Errors from LLM providers.
12#[non_exhaustive]
13#[derive(Debug, Error)]
14pub enum ProviderError {
15    /// HTTP or network request failed.
16    #[error("request failed: {0}")]
17    RequestFailed(String),
18
19    /// Provider rate-limited the request.
20    #[error("rate limited")]
21    RateLimited,
22
23    /// Authentication/authorization failed.
24    #[error("auth failed: {0}")]
25    AuthFailed(String),
26
27    /// Could not parse the provider's response.
28    #[error("invalid response: {0}")]
29    InvalidResponse(String),
30
31    /// Catch-all for other errors.
32    #[error("{0}")]
33    Other(#[from] Box<dyn std::error::Error + Send + Sync>),
34}
35
36impl ProviderError {
37    /// Whether retrying this request might succeed.
38    pub fn is_retryable(&self) -> bool {
39        matches!(
40            self,
41            ProviderError::RateLimited | ProviderError::RequestFailed(_)
42        )
43    }
44}
45
46/// LLM provider interface.
47///
48/// Each provider (Anthropic, OpenAI, Ollama) implements this trait.
49/// Provider-native features (truncation, caching, thinking blocks)
50/// are handled by the provider impl using `ProviderRequest.extra`.
51///
52/// This trait uses RPITIT and is NOT object-safe. That's intentional —
53/// `NeuronTurn<P: Provider>` is generic, and the object-safe boundary
54/// is `layer0::Turn`.
55pub trait Provider: Send + Sync {
56    /// Send a completion request to the provider.
57    fn complete(
58        &self,
59        request: ProviderRequest,
60    ) -> impl Future<Output = Result<ProviderResponse, ProviderError>> + Send;
61}
62
#[cfg(test)]
mod tests {
    use super::*;

    type BoxedError = Box<dyn std::error::Error + Send + Sync>;

    /// Builds the catch-all variant from a plain message via the `From` impl.
    fn other(msg: &str) -> ProviderError {
        let boxed: BoxedError = msg.into();
        ProviderError::from(boxed)
    }

    #[test]
    fn provider_error_display() {
        let cases = [
            (ProviderError::RequestFailed("timeout".into()), "request failed: timeout"),
            (ProviderError::RateLimited, "rate limited"),
            (ProviderError::AuthFailed("bad key".into()), "auth failed: bad key"),
            (ProviderError::InvalidResponse("bad json".into()), "invalid response: bad json"),
        ];
        for (err, expected) in cases {
            assert_eq!(err.to_string(), expected);
        }
    }

    #[test]
    fn provider_error_retryable() {
        let transient = [
            ProviderError::RateLimited,
            ProviderError::RequestFailed("timeout".into()),
        ];
        let permanent = [
            ProviderError::AuthFailed("bad key".into()),
            ProviderError::InvalidResponse("x".into()),
        ];
        assert!(transient.iter().all(ProviderError::is_retryable));
        assert!(!permanent.iter().any(ProviderError::is_retryable));
    }

    #[test]
    fn provider_error_from_boxed() {
        let err = other("some error");
        assert!(matches!(err, ProviderError::Other(_)));
        assert!(!err.is_retryable());
    }

    #[test]
    fn provider_error_other_display() {
        assert_eq!(other("custom error").to_string(), "custom error");
    }
}