bamboo_infrastructure/llm/provider.rs
//! LLM provider trait and types
//!
//! This module defines the interface for LLM (Large Language Model) providers,
//! enabling support for multiple LLM backends through a common trait.

use crate::llm::types::LLMChunk;
use async_trait::async_trait;
use bamboo_domain::Message;
use bamboo_domain::ReasoningEffort;
use bamboo_domain::ToolSchema;
use futures::Stream;
use std::pin::Pin;
use thiserror::Error;

/// Errors that can occur when working with LLM providers
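///
/// A minimal sketch of matching on these variants at a call site (assuming
/// `result` holds the outcome of a provider call):
///
/// ```ignore
/// match result {
///     Ok(stream) => { /* consume the stream */ }
///     Err(LLMError::Auth(msg)) => eprintln!("check credentials: {msg}"),
///     Err(LLMError::Api(msg)) => eprintln!("API rejected the request: {msg}"),
///     Err(other) => eprintln!("LLM call failed: {other}"),
/// }
/// ```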
#[derive(Error, Debug)]
pub enum LLMError {
    /// HTTP request/response errors
    #[error("HTTP error: {0}")]
    Http(#[from] reqwest::Error),

    /// JSON serialization/deserialization errors
    #[error("JSON error: {0}")]
    Json(#[from] serde_json::Error),

    /// Streaming response errors
    #[error("Stream error: {0}")]
    Stream(String),

    /// LLM API errors (rate limits, invalid requests, etc.)
    #[error("API error: {0}")]
    Api(String),

    /// Authentication/authorization errors
    #[error("Authentication error: {0}")]
    Auth(String),

    /// Protocol conversion errors
    #[error("Protocol conversion error: {0}")]
    Protocol(#[from] crate::llm::protocol::ProtocolError),
}

/// Convenient result type for LLM operations
pub type Result<T> = std::result::Result<T, LLMError>;

/// Type alias for boxed streaming LLM responses
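///
/// A sketch of consuming such a stream with `futures::StreamExt` (assuming
/// `stream` was obtained from [`LLMProvider::chat_stream`]):
///
/// ```ignore
/// use futures::StreamExt;
///
/// while let Some(chunk) = stream.next().await {
///     let chunk = chunk?; // each item is a `Result<LLMChunk>`
///     // handle the partial response...
/// }
/// ```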
pub type LLMStream = Pin<Box<dyn Stream<Item = Result<LLMChunk>> + Send>>;

/// Metadata for a provider model returned by
/// [`LLMProvider::list_model_info`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ProviderModelInfo {
    /// Model identifier.
    pub id: String,
    /// Maximum context window (input + output) in tokens when known.
    pub max_context_tokens: Option<u32>,
    /// Maximum output/completion tokens when known.
    pub max_output_tokens: Option<u32>,
}

impl ProviderModelInfo {
    /// Create metadata with only a model id (no token limits).
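    ///
    /// A minimal example (the `"gpt-test"` id is illustrative):
    ///
    /// ```ignore
    /// let info = ProviderModelInfo::from_id("gpt-test");
    /// assert_eq!(info.id, "gpt-test");
    /// assert_eq!(info.max_context_tokens, None);
    /// assert_eq!(info.max_output_tokens, None);
    /// ```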
    pub fn from_id(id: impl Into<String>) -> Self {
        Self {
            id: id.into(),
            max_context_tokens: None,
            max_output_tokens: None,
        }
    }
}

/// Responses API-specific request-time controls for provider calls.
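///
/// Since the struct derives [`Default`], callers can set only the fields they
/// need. A sketch with illustrative values:
///
/// ```ignore
/// let options = ResponsesRequestOptions {
///     reasoning_summary: Some("auto".to_string()),
///     store: Some(false),
///     ..Default::default()
/// };
/// ```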
#[derive(Debug, Clone, Default)]
pub struct ResponsesRequestOptions {
    /// Optional top-level instructions for Responses API requests.
    pub instructions: Option<String>,
    /// Optional reasoning summary control for Responses API requests
    /// (e.g. "auto", "concise", "detailed").
    pub reasoning_summary: Option<String>,
    /// Optional include list for Responses API requests.
    pub include: Option<Vec<String>>,
    /// Whether the Responses API should store the response server-side.
    pub store: Option<bool>,
    /// Optional continuation handle for stateful Responses API turns.
    pub previous_response_id: Option<String>,
    /// Optional truncation mode for Responses API requests
    /// (e.g. "auto", "disabled").
    pub truncation: Option<String>,
    /// Optional text verbosity for Responses API requests
    /// (e.g. "low", "medium", "high").
    pub text_verbosity: Option<String>,
}

/// General request-time controls for provider calls.
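///
/// Like [`ResponsesRequestOptions`], this derives [`Default`], so callers
/// spell out only the overrides they need. A sketch with an illustrative
/// session id:
///
/// ```ignore
/// let options = LLMRequestOptions {
///     session_id: Some("session-123".to_string()),
///     parallel_tool_calls: Some(true),
///     ..Default::default()
/// };
/// ```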
#[derive(Debug, Clone, Default)]
pub struct LLMRequestOptions {
    /// Session identifier used for request-scoped logging correlation.
    pub session_id: Option<String>,
    /// Override reasoning effort for this request.
    pub reasoning_effort: Option<ReasoningEffort>,
    /// Request provider-side parallel tool call planning when supported.
    ///
    /// - OpenAI/Copilot: maps to `parallel_tool_calls`
    /// - Anthropic: maps to `tool_choice.disable_parallel_tool_use` (inverse)
    pub parallel_tool_calls: Option<bool>,
    /// Responses API-specific overrides.
    pub responses: Option<ResponsesRequestOptions>,
}

/// Trait for LLM provider implementations
///
/// This trait defines the interface that all LLM providers must implement
/// to work with Bamboo's agent system. Providers handle communication with
/// specific LLM services (OpenAI, Anthropic, local models, etc.).
///
/// # Design Principle
///
/// The `model` parameter is **required** in `chat_stream`, not optional.
/// This ensures that the calling code explicitly specifies which model to use,
/// preventing accidental use of unintended models and making model selection
/// explicit and auditable.
///
/// # Example
///
/// ```ignore
/// use bamboo_domain::{Message, ToolSchema};
/// use bamboo_infrastructure::llm::provider::{LLMProvider, Result};
///
/// async fn use_provider(
///     provider: &dyn LLMProvider,
///     messages: &[Message],
///     tools: &[ToolSchema],
/// ) -> Result<()> {
///     let mut stream = provider
///         .chat_stream(
///             messages,
///             tools,
///             Some(4096),
///             "claude-sonnet-4-6", // Model is required
///         )
///         .await?;
///     // consume `stream`...
///     Ok(())
/// }
/// ```
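///
/// Implementors only need to provide `chat_stream`; the other methods have
/// default implementations. A minimal sketch (the `EchoProvider` name is
/// hypothetical):
///
/// ```ignore
/// struct EchoProvider;
///
/// #[async_trait]
/// impl LLMProvider for EchoProvider {
///     async fn chat_stream(
///         &self,
///         _messages: &[Message],
///         _tools: &[ToolSchema],
///         _max_output_tokens: Option<u32>,
///         _model: &str,
///     ) -> Result<LLMStream> {
///         // A real provider would call its backend here; this one just
///         // ends the stream immediately.
///         Ok(Box::pin(futures::stream::empty()))
///     }
/// }
/// ```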
#[async_trait]
pub trait LLMProvider: Send + Sync {
    /// Stream chat completion from the LLM
    ///
    /// This is the primary method for interacting with LLMs, returning
    /// a stream of response chunks that can be processed incrementally.
    ///
    /// # Arguments
    ///
    /// * `messages` - Conversation history and current prompt
    /// * `tools` - Available tools the LLM can call
    /// * `max_output_tokens` - Optional limit on response length
    /// * `model` - **Required** model identifier (e.g., "claude-sonnet-4-6")
    ///
    /// # Returns
    ///
    /// A stream of `LLMChunk` items containing partial responses
    ///
    /// # Errors
    ///
    /// Returns `LLMError` on network failures, API errors, or invalid requests
    async fn chat_stream(
        &self,
        messages: &[Message],
        tools: &[ToolSchema],
        max_output_tokens: Option<u32>,
        model: &str,
    ) -> Result<LLMStream>;

    /// Stream chat completion with optional request-level controls.
    ///
    /// Default implementation preserves backward compatibility by delegating
    /// to [`LLMProvider::chat_stream`], ignoring `options`.
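    ///
    /// A sketch of a call that threads options through (the model id and
    /// token limit are illustrative):
    ///
    /// ```ignore
    /// let options = LLMRequestOptions {
    ///     parallel_tool_calls: Some(true),
    ///     ..Default::default()
    /// };
    /// let stream = provider
    ///     .chat_stream_with_options(&messages, &tools, Some(4096), "gpt-test", Some(&options))
    ///     .await?;
    /// ```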
    async fn chat_stream_with_options(
        &self,
        messages: &[Message],
        tools: &[ToolSchema],
        max_output_tokens: Option<u32>,
        model: &str,
        _options: Option<&LLMRequestOptions>,
    ) -> Result<LLMStream> {
        self.chat_stream(messages, tools, max_output_tokens, model)
            .await
    }

    /// Lists available models from this provider
    ///
    /// Returns a list of model identifiers that can be used with `chat_stream`.
    /// Default implementation returns an empty list.
    async fn list_models(&self) -> Result<Vec<String>> {
        // Default implementation returns an empty list
        Ok(vec![])
    }

    /// Lists available models with optional token limit metadata.
    ///
    /// Default implementation preserves backward compatibility by adapting
    /// `list_models()` output into metadata entries without limits.
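    ///
    /// Providers that know their limits can override this; a sketch with
    /// illustrative numbers:
    ///
    /// ```ignore
    /// async fn list_model_info(&self) -> Result<Vec<ProviderModelInfo>> {
    ///     Ok(vec![ProviderModelInfo {
    ///         id: "gpt-test".to_string(),
    ///         max_context_tokens: Some(128_000),
    ///         max_output_tokens: Some(4_096),
    ///     }])
    /// }
    /// ```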
    async fn list_model_info(&self) -> Result<Vec<ProviderModelInfo>> {
        Ok(self
            .list_models()
            .await?
            .into_iter()
            .map(ProviderModelInfo::from_id)
            .collect())
    }
}

#[cfg(test)]
mod tests {
    use std::sync::{Arc, Mutex};

    use async_trait::async_trait;
    use futures::{stream, StreamExt};

    use super::*;

    // Test double that records the model and token-limit arguments passed to
    // `chat_stream` and returns an empty stream.
    #[derive(Clone, Default)]
    struct RecordingProvider {
        requested_models: Arc<Mutex<Vec<String>>>,
        requested_max_tokens: Arc<Mutex<Vec<Option<u32>>>>,
    }

    #[async_trait]
    impl LLMProvider for RecordingProvider {
        async fn chat_stream(
            &self,
            _messages: &[Message],
            _tools: &[ToolSchema],
            max_output_tokens: Option<u32>,
            model: &str,
        ) -> Result<LLMStream> {
            if let Ok(mut models) = self.requested_models.lock() {
                models.push(model.to_string());
            }
            if let Ok(mut max_tokens) = self.requested_max_tokens.lock() {
                max_tokens.push(max_output_tokens);
            }

            Ok(Box::pin(stream::empty()))
        }
    }

    #[tokio::test]
    async fn chat_stream_with_options_delegates_to_chat_stream_with_same_model_and_tokens() {
        let provider = RecordingProvider::default();
        let options = LLMRequestOptions::default();

        let mut stream = provider
            .chat_stream_with_options(&[], &[], Some(512), "gpt-test", Some(&options))
            .await
            .expect("delegation should succeed");
        assert!(stream.next().await.is_none());

        assert_eq!(
            provider
                .requested_models
                .lock()
                .expect("lock poisoned")
                .as_slice(),
            ["gpt-test"]
        );
        assert_eq!(
            provider
                .requested_max_tokens
                .lock()
                .expect("lock poisoned")
                .as_slice(),
            [Some(512)]
        );
    }

    #[tokio::test]
    async fn list_models_returns_empty_by_default() {
        let provider = RecordingProvider::default();
        let models = provider
            .list_models()
            .await
            .expect("default list_models should succeed");
        assert!(models.is_empty());
    }
}
277}