Skip to main content

bamboo_infrastructure/llm/
provider.rs

//! LLM provider trait and types
//!
//! This module defines the interface for LLM (Large Language Model) providers,
//! enabling support for multiple LLM backends through a common trait.
5
6use crate::llm::types::LLMChunk;
7use async_trait::async_trait;
8use bamboo_domain::Message;
9use bamboo_domain::ReasoningEffort;
10use bamboo_domain::ToolSchema;
11use futures::Stream;
12use std::pin::Pin;
13use thiserror::Error;
14
15/// Errors that can occur when working with LLM providers
16#[derive(Error, Debug)]
17pub enum LLMError {
18    /// HTTP request/response errors
19    #[error("HTTP error: {0}")]
20    Http(#[from] reqwest::Error),
21
22    /// JSON serialization/deserialization errors
23    #[error("JSON error: {0}")]
24    Json(#[from] serde_json::Error),
25
26    /// Streaming response errors
27    #[error("Stream error: {0}")]
28    Stream(String),
29
30    /// LLM API errors (rate limits, invalid requests, etc.)
31    #[error("API error: {0}")]
32    Api(String),
33
34    /// Authentication/authorization errors
35    #[error("Authentication error: {0}")]
36    Auth(String),
37
38    /// Protocol conversion errors
39    #[error("Protocol conversion error: {0}")]
40    Protocol(#[from] crate::llm::protocol::ProtocolError),
41}
42
43/// Convenient result type for LLM operations
44pub type Result<T> = std::result::Result<T, LLMError>;
45
46/// Type alias for boxed streaming LLM responses
47pub type LLMStream = Pin<Box<dyn Stream<Item = Result<LLMChunk>> + Send>>;
48
/// Metadata for a provider model returned by `list_model_info`.
///
/// Both token limits are optional; `None` means the limit is not known.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ProviderModelInfo {
    /// Model identifier.
    pub id: String,
    /// Maximum context window (input + output) in tokens when known.
    pub max_context_tokens: Option<u32>,
    /// Maximum output/completion tokens when known.
    pub max_output_tokens: Option<u32>,
}

impl ProviderModelInfo {
    /// Create metadata with only a model id (no token limits).
    ///
    /// # Arguments
    ///
    /// * `id` - Anything convertible into a `String` (e.g. `&str` or `String`).
    ///
    /// # Returns
    ///
    /// A `ProviderModelInfo` whose `max_context_tokens` and
    /// `max_output_tokens` are both `None`.
    pub fn from_id(id: impl Into<String>) -> Self {
        Self {
            id: id.into(),
            max_context_tokens: None,
            max_output_tokens: None,
        }
    }
}
70
/// Responses API specific request overrides.
///
/// Every field is optional; `Default` leaves all of them unset (`None`).
#[derive(Debug, Clone, Default)]
pub struct ResponsesRequestOptions {
    /// Optional top-level instructions for Responses API requests.
    pub instructions: Option<String>,
    /// Optional reasoning summary control for Responses API requests
    /// (e.g. "auto", "concise", "detailed").
    pub reasoning_summary: Option<String>,
    /// Optional include list for Responses API requests.
    pub include: Option<Vec<String>>,
    /// Whether Responses API should store the response server-side.
    pub store: Option<bool>,
    /// Optional continuation handle for stateful Responses API turns.
    pub previous_response_id: Option<String>,
    /// Optional truncation mode for Responses API requests
    /// (e.g. "auto", "disabled").
    pub truncation: Option<String>,
    /// Optional text verbosity for Responses API requests
    /// (e.g. "low", "medium", "high").
    pub text_verbosity: Option<String>,
}
92
93/// Optional request-time controls for provider calls.
94#[derive(Debug, Clone, Default)]
95pub struct LLMRequestOptions {
96    /// Session identifier used for request-scoped logging correlation.
97    pub session_id: Option<String>,
98    /// Override reasoning effort for this request.
99    pub reasoning_effort: Option<ReasoningEffort>,
100    /// Request provider-side parallel tool call planning when supported.
101    ///
102    /// - OpenAI/Copilot: maps to `parallel_tool_calls`
103    /// - Anthropic: maps to `tool_choice.disable_parallel_tool_use` (inverse)
104    pub parallel_tool_calls: Option<bool>,
105    /// Responses API specific overrides.
106    pub responses: Option<ResponsesRequestOptions>,
107    /// Purpose of this request for observability (e.g., "agent_loop", "task_evaluation").
108    pub request_purpose: Option<String>,
109}
110
111/// Trait for LLM provider implementations
112///
113/// This trait defines the interface that all LLM providers must implement
114/// to work with Bamboo's agent system. Providers handle communication with
115/// specific LLM services (OpenAI, Anthropic, local models, etc.).
116///
117/// # Design Principle
118///
119/// The `model` parameter is **required** in `chat_stream`, not optional.
120/// This ensures that the calling code explicitly specifies which model to use,
121/// preventing accidental use of unintended models and making model selection
122/// explicit and auditable.
123///
124/// # Example
125///
126/// ```ignore
127/// use bamboo_agent::agent::llm::provider::LLMProvider;
128///
129/// async fn use_provider(provider: &dyn LLMProvider) {
130///     let stream = provider.chat_stream(
131///         &messages,
132///         &tools,
133///         Some(4096),
134///         "claude-sonnet-4-6", // Model is required
135///     ).await?;
136/// }
137/// ```
138#[async_trait]
139pub trait LLMProvider: Send + Sync {
140    /// Stream chat completion from the LLM
141    ///
142    /// This is the primary method for interacting with LLMs, returning
143    /// a stream of response chunks that can be processed incrementally.
144    ///
145    /// # Arguments
146    ///
147    /// * `messages` - Conversation history and current prompt
148    /// * `tools` - Available tools the LLM can call
149    /// * `max_output_tokens` - Optional limit on response length
150    /// * `model` - **Required** model identifier (e.g., "claude-sonnet-4-6")
151    ///
152    /// # Returns
153    ///
154    /// A stream of `LLMChunk` items containing partial responses
155    ///
156    /// # Errors
157    ///
158    /// Returns `LLMError` on network failures, API errors, or invalid requests
159    async fn chat_stream(
160        &self,
161        messages: &[Message],
162        tools: &[ToolSchema],
163        max_output_tokens: Option<u32>,
164        model: &str,
165    ) -> Result<LLMStream>;
166
167    /// Stream chat completion with optional request-level controls.
168    ///
169    /// Default implementation preserves backward compatibility by delegating to
170    /// [`LLMProvider::chat_stream`].
171    async fn chat_stream_with_options(
172        &self,
173        messages: &[Message],
174        tools: &[ToolSchema],
175        max_output_tokens: Option<u32>,
176        model: &str,
177        _options: Option<&LLMRequestOptions>,
178    ) -> Result<LLMStream> {
179        self.chat_stream(messages, tools, max_output_tokens, model)
180            .await
181    }
182
183    /// Lists available models from this provider
184    ///
185    /// Returns a list of model identifiers that can be used with `chat_stream`.
186    /// Default implementation returns an empty list.
187    async fn list_models(&self) -> Result<Vec<String>> {
188        // Default implementation returns empty list
189        Ok(vec![])
190    }
191
192    /// Lists available models with optional token limit metadata.
193    ///
194    /// Default implementation preserves backward compatibility by adapting
195    /// `list_models()` output into metadata entries without limits.
196    async fn list_model_info(&self) -> Result<Vec<ProviderModelInfo>> {
197        Ok(self
198            .list_models()
199            .await?
200            .into_iter()
201            .map(ProviderModelInfo::from_id)
202            .collect())
203    }
204}
205
#[cfg(test)]
mod tests {
    use std::sync::{Arc, Mutex};

    use async_trait::async_trait;
    use futures::{stream, StreamExt};

    use super::*;

    /// Test double that records the `model` and `max_output_tokens` passed to
    /// `chat_stream` and returns an empty stream.
    #[derive(Clone, Default)]
    struct RecordingProvider {
        requested_models: Arc<Mutex<Vec<String>>>,
        requested_max_tokens: Arc<Mutex<Vec<Option<u32>>>>,
    }

    #[async_trait]
    impl LLMProvider for RecordingProvider {
        async fn chat_stream(
            &self,
            _messages: &[Message],
            _tools: &[ToolSchema],
            max_output_tokens: Option<u32>,
            model: &str,
        ) -> Result<LLMStream> {
            // Fail loudly on a poisoned lock instead of silently skipping the
            // recording; the assertions below also treat poisoning as fatal.
            self.requested_models
                .lock()
                .expect("lock poisoned")
                .push(model.to_string());
            self.requested_max_tokens
                .lock()
                .expect("lock poisoned")
                .push(max_output_tokens);

            Ok(Box::pin(stream::empty()))
        }
    }

    /// The default `chat_stream_with_options` must forward the model and the
    /// token limit unchanged to `chat_stream`.
    #[tokio::test]
    async fn chat_stream_with_options_delegates_to_chat_stream_with_same_model_and_tokens() {
        let provider = RecordingProvider::default();
        let options = LLMRequestOptions::default();

        let mut stream = provider
            .chat_stream_with_options(&[], &[], Some(512), "gpt-test", Some(&options))
            .await
            .expect("delegation should succeed");
        assert!(stream.next().await.is_none());

        assert_eq!(
            provider
                .requested_models
                .lock()
                .expect("lock poisoned")
                .as_slice(),
            ["gpt-test"]
        );
        assert_eq!(
            provider
                .requested_max_tokens
                .lock()
                .expect("lock poisoned")
                .as_slice(),
            [Some(512)]
        );
    }

    /// A provider that does not override `list_models` reports no models.
    #[tokio::test]
    async fn list_models_returns_empty_by_default() {
        let provider = RecordingProvider::default();
        let models = provider
            .list_models()
            .await
            .expect("default list_models should succeed");
        assert!(models.is_empty());
    }

    #[test]
    fn request_options_default_has_no_purpose() {
        let opts = LLMRequestOptions::default();
        assert!(opts.request_purpose.is_none());
    }

    #[test]
    fn request_options_purpose_is_set_and_readable() {
        let opts = LLMRequestOptions {
            request_purpose: Some("title_generation".to_string()),
            ..Default::default()
        };
        assert_eq!(opts.request_purpose.as_deref(), Some("title_generation"));
    }
}
294}