bamboo_infrastructure/llm/provider.rs
1//! LLM provider trait and types
2//!
3//! This module defines the interface for LLM (Large Language Model) providers,
4//! enabling support for multiple LLM backends through a common trait.
5
6use crate::llm::types::LLMChunk;
7use async_trait::async_trait;
8use bamboo_domain::Message;
9use bamboo_domain::ReasoningEffort;
10use bamboo_domain::ToolSchema;
11use futures::Stream;
12use std::pin::Pin;
13use thiserror::Error;
14
15/// Errors that can occur when working with LLM providers
16#[derive(Error, Debug)]
17pub enum LLMError {
18 /// HTTP request/response errors
19 #[error("HTTP error: {0}")]
20 Http(#[from] reqwest::Error),
21
22 /// JSON serialization/deserialization errors
23 #[error("JSON error: {0}")]
24 Json(#[from] serde_json::Error),
25
26 /// Streaming response errors
27 #[error("Stream error: {0}")]
28 Stream(String),
29
30 /// LLM API errors (rate limits, invalid requests, etc.)
31 #[error("API error: {0}")]
32 Api(String),
33
34 /// Authentication/authorization errors
35 #[error("Authentication error: {0}")]
36 Auth(String),
37
38 /// Protocol conversion errors
39 #[error("Protocol conversion error: {0}")]
40 Protocol(#[from] crate::llm::protocol::ProtocolError),
41}
42
43/// Convenient result type for LLM operations
44pub type Result<T> = std::result::Result<T, LLMError>;
45
46/// Type alias for boxed streaming LLM responses
47pub type LLMStream = Pin<Box<dyn Stream<Item = Result<LLMChunk>> + Send>>;
48
/// Describes a single model reported by a provider through `list_model_info`.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct ProviderModelInfo {
    /// Identifier of the model.
    pub id: String,
    /// Size of the context window (input + output) in tokens, when known.
    pub max_context_tokens: Option<u32>,
    /// Upper bound on output/completion tokens, when known.
    pub max_output_tokens: Option<u32>,
}

impl ProviderModelInfo {
    /// Build an entry that carries only the model identifier.
    ///
    /// Both token limits are left as `None`.
    pub fn from_id(id: impl Into<String>) -> Self {
        let id = id.into();
        Self {
            id,
            max_context_tokens: None,
            max_output_tokens: None,
        }
    }
}
70
71/// Optional request-time controls for provider calls.
72#[derive(Debug, Clone, Default)]
73pub struct ResponsesRequestOptions {
74 /// Optional top-level instructions for Responses API requests.
75 pub instructions: Option<String>,
76 /// Optional message list to serialize into the Responses API `input` array.
77 ///
78 /// When omitted, providers fall back to the generic `messages` slice passed
79 /// to `chat_stream_with_options`. This lets the engine provide a
80 /// Responses-specific input view (for example, without a duplicated stable
81 /// system message) while preserving backward compatibility for non-Responses
82 /// callers and providers.
83 pub input_messages: Option<Vec<Message>>,
84 /// Optional reasoning summary control for Responses API requests
85 /// (e.g. "auto", "concise", "detailed").
86 pub reasoning_summary: Option<String>,
87 /// Optional include list for Responses API requests.
88 pub include: Option<Vec<String>>,
89 /// Whether Responses API should store the response server-side.
90 pub store: Option<bool>,
91 /// Optional continuation handle for stateful Responses API turns.
92 pub previous_response_id: Option<String>,
93 /// Optional truncation mode for Responses API requests
94 /// (e.g. "auto", "disabled").
95 pub truncation: Option<String>,
96 /// Optional text verbosity for Responses API requests
97 /// (e.g. "low", "medium", "high").
98 pub text_verbosity: Option<String>,
99}
100
101/// Optional request-time controls for provider calls.
102#[derive(Debug, Clone, Default)]
103pub struct LLMRequestOptions {
104 /// Session identifier used for request-scoped logging correlation.
105 pub session_id: Option<String>,
106 /// Override reasoning effort for this request.
107 pub reasoning_effort: Option<ReasoningEffort>,
108 /// Request provider-side parallel tool call planning when supported.
109 ///
110 /// - OpenAI/Copilot: maps to `parallel_tool_calls`
111 /// - Anthropic: maps to `tool_choice.disable_parallel_tool_use` (inverse)
112 pub parallel_tool_calls: Option<bool>,
113 /// Responses API specific overrides.
114 pub responses: Option<ResponsesRequestOptions>,
115 /// Purpose of this request for observability (e.g., "agent_loop", "task_evaluation").
116 pub request_purpose: Option<String>,
117 /// Provider-agnostic prompt-cache plan describing the stable, cacheable
118 /// prefix of this request. Providers render it in their own dialect
119 /// (Anthropic `cache_control` breakpoints; OpenAI/Gemini rely on the stable
120 /// prefix automatically). `None` means "no explicit cache hints".
121 pub cache: Option<crate::llm::cache::PromptCachePlan>,
122}
123
124/// Trait for LLM provider implementations
125///
126/// This trait defines the interface that all LLM providers must implement
127/// to work with Bamboo's agent system. Providers handle communication with
128/// specific LLM services (OpenAI, Anthropic, local models, etc.).
129///
130/// # Design Principle
131///
132/// The `model` parameter is **required** in `chat_stream`, not optional.
133/// This ensures that the calling code explicitly specifies which model to use,
134/// preventing accidental use of unintended models and making model selection
135/// explicit and auditable.
136///
137/// # Example
138///
139/// ```ignore
140/// use bamboo_agent::agent::llm::provider::LLMProvider;
141///
142/// async fn use_provider(provider: &dyn LLMProvider) {
143/// let stream = provider.chat_stream(
144/// &messages,
145/// &tools,
146/// Some(4096),
147/// "claude-sonnet-4-6", // Model is required
148/// ).await?;
149/// }
150/// ```
151#[async_trait]
152pub trait LLMProvider: Send + Sync {
153 /// Stream chat completion from the LLM
154 ///
155 /// This is the primary method for interacting with LLMs, returning
156 /// a stream of response chunks that can be processed incrementally.
157 ///
158 /// # Arguments
159 ///
160 /// * `messages` - Conversation history and current prompt
161 /// * `tools` - Available tools the LLM can call
162 /// * `max_output_tokens` - Optional limit on response length
163 /// * `model` - **Required** model identifier (e.g., "claude-sonnet-4-6")
164 ///
165 /// # Returns
166 ///
167 /// A stream of `LLMChunk` items containing partial responses
168 ///
169 /// # Errors
170 ///
171 /// Returns `LLMError` on network failures, API errors, or invalid requests
172 async fn chat_stream(
173 &self,
174 messages: &[Message],
175 tools: &[ToolSchema],
176 max_output_tokens: Option<u32>,
177 model: &str,
178 ) -> Result<LLMStream>;
179
180 /// Stream chat completion with optional request-level controls.
181 ///
182 /// Default implementation preserves backward compatibility by delegating to
183 /// [`LLMProvider::chat_stream`].
184 async fn chat_stream_with_options(
185 &self,
186 messages: &[Message],
187 tools: &[ToolSchema],
188 max_output_tokens: Option<u32>,
189 model: &str,
190 _options: Option<&LLMRequestOptions>,
191 ) -> Result<LLMStream> {
192 self.chat_stream(messages, tools, max_output_tokens, model)
193 .await
194 }
195
196 /// Lists available models from this provider
197 ///
198 /// Returns a list of model identifiers that can be used with `chat_stream`.
199 /// Default implementation returns an empty list.
200 async fn list_models(&self) -> Result<Vec<String>> {
201 // Default implementation returns empty list
202 Ok(vec![])
203 }
204
205 /// Lists available models with optional token limit metadata.
206 ///
207 /// Default implementation preserves backward compatibility by adapting
208 /// `list_models()` output into metadata entries without limits.
209 async fn list_model_info(&self) -> Result<Vec<ProviderModelInfo>> {
210 Ok(self
211 .list_models()
212 .await?
213 .into_iter()
214 .map(ProviderModelInfo::from_id)
215 .collect())
216 }
217}
218
#[cfg(test)]
mod tests {
    use std::sync::{Arc, Mutex};

    use async_trait::async_trait;
    use futures::{stream, StreamExt};

    use super::*;

    /// Test double that remembers the `model` and `max_output_tokens` of every
    /// `chat_stream` call and answers with an empty stream.
    #[derive(Clone, Default)]
    struct RecordingProvider {
        requested_models: Arc<Mutex<Vec<String>>>,
        requested_max_tokens: Arc<Mutex<Vec<Option<u32>>>>,
    }

    #[async_trait]
    impl LLMProvider for RecordingProvider {
        async fn chat_stream(
            &self,
            _messages: &[Message],
            _tools: &[ToolSchema],
            max_output_tokens: Option<u32>,
            model: &str,
        ) -> Result<LLMStream> {
            // Recording is best-effort: a poisoned lock is skipped instead of
            // panicking inside the provider.
            if let Ok(mut seen_models) = self.requested_models.lock() {
                seen_models.push(model.to_owned());
            }
            if let Ok(mut seen_limits) = self.requested_max_tokens.lock() {
                seen_limits.push(max_output_tokens);
            }

            Ok(Box::pin(stream::empty()))
        }
    }

    /// The default `chat_stream_with_options` must forward `model` and
    /// `max_output_tokens` untouched to `chat_stream`.
    #[tokio::test]
    async fn chat_stream_with_options_delegates_to_chat_stream_with_same_model_and_tokens() {
        let provider = RecordingProvider::default();
        let options = LLMRequestOptions::default();
        let no_messages: &[Message] = &[];
        let no_tools: &[ToolSchema] = &[];

        let mut stream = provider
            .chat_stream_with_options(no_messages, no_tools, Some(512), "gpt-test", Some(&options))
            .await
            .expect("delegation should succeed");
        let first_chunk = stream.next().await;
        assert!(first_chunk.is_none());

        // Snapshot the recordings so no lock guard outlives its statement.
        let models = provider
            .requested_models
            .lock()
            .expect("lock poisoned")
            .clone();
        assert_eq!(models, ["gpt-test"]);

        let limits = provider
            .requested_max_tokens
            .lock()
            .expect("lock poisoned")
            .clone();
        assert_eq!(limits, [Some(512_u32)]);
    }

    /// A provider that does not override `list_models` reports no models.
    #[tokio::test]
    async fn list_models_returns_empty_by_default() {
        let provider = RecordingProvider::default();
        let listed = provider.list_models().await;
        let models = listed.expect("default list_models should succeed");
        assert!(models.is_empty());
    }

    /// `Default` leaves `request_purpose` unset.
    #[test]
    fn request_options_default_has_no_purpose() {
        let LLMRequestOptions {
            request_purpose, ..
        } = LLMRequestOptions::default();
        assert!(request_purpose.is_none());
    }

    /// A purpose supplied at construction can be read back unchanged.
    #[test]
    fn request_options_purpose_is_set_and_readable() {
        let purpose = "title_generation".to_string();
        let opts = LLMRequestOptions {
            request_purpose: Some(purpose),
            ..Default::default()
        };
        assert_eq!(opts.request_purpose.as_deref(), Some("title_generation"));
    }
}
307}