bamboo_llm/provider.rs
1//! LLM provider trait and types
2//!
3//! This module defines the interface for LLM (Large Language Model) providers,
4//! enabling support for multiple LLM backends through a common trait.
5
6use crate::prompt_ir::PromptIR;
7use crate::types::LLMChunk;
8use async_trait::async_trait;
9use bamboo_domain::Message;
10use bamboo_domain::ReasoningEffort;
11use bamboo_domain::ToolSchema;
12use futures::Stream;
13use std::pin::Pin;
14use thiserror::Error;
15
16/// Errors that can occur when working with LLM providers
17#[derive(Error, Debug)]
18pub enum LLMError {
19 /// HTTP request/response errors
20 #[error("HTTP error: {0}")]
21 Http(#[from] reqwest::Error),
22
23 /// JSON serialization/deserialization errors
24 #[error("JSON error: {0}")]
25 Json(#[from] serde_json::Error),
26
27 /// Streaming response errors
28 #[error("Stream error: {0}")]
29 Stream(String),
30
31 /// LLM API errors (rate limits, invalid requests, etc.)
32 #[error("API error: {0}")]
33 Api(String),
34
35 /// Authentication/authorization errors
36 #[error("Authentication error: {0}")]
37 Auth(String),
38
39 /// Protocol conversion errors
40 #[error("Protocol conversion error: {0}")]
41 Protocol(#[from] crate::protocol::ProtocolError),
42}
43
44/// Convenient result type for LLM operations
45pub type Result<T> = std::result::Result<T, LLMError>;
46
47/// Type alias for boxed streaming LLM responses
48pub type LLMStream = Pin<Box<dyn Stream<Item = Result<LLMChunk>> + Send>>;
49
/// One model entry as reported by [`LLMProvider::list_model_info`].
///
/// Token limits are `None` whenever the provider does not report them.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ProviderModelInfo {
    /// The model's identifier.
    pub id: String,
    /// Combined input + output token budget (context window), if known.
    pub max_context_tokens: Option<u32>,
    /// Budget for generated (output/completion) tokens, if known.
    pub max_output_tokens: Option<u32>,
}

impl ProviderModelInfo {
    /// Builds an entry that carries only `id`; both token limits stay `None`.
    pub fn from_id(id: impl Into<String>) -> Self {
        let id: String = id.into();
        Self {
            id,
            max_output_tokens: None,
            max_context_tokens: None,
        }
    }
}
71
72/// Optional request-time controls for provider calls.
73#[derive(Debug, Clone, Default)]
74pub struct ResponsesRequestOptions {
75 /// Optional top-level instructions for Responses API requests.
76 pub instructions: Option<String>,
77 /// Optional message list to serialize into the Responses API `input` array.
78 ///
79 /// When omitted, providers fall back to the generic `messages` slice passed
80 /// to `chat_stream_with_options`. This lets the engine provide a
81 /// Responses-specific input view (for example, without a duplicated stable
82 /// system message) while preserving backward compatibility for non-Responses
83 /// callers and providers.
84 pub input_messages: Option<Vec<Message>>,
85 /// Optional reasoning summary control for Responses API requests
86 /// (e.g. "auto", "concise", "detailed").
87 pub reasoning_summary: Option<String>,
88 /// Optional include list for Responses API requests.
89 pub include: Option<Vec<String>>,
90 /// Whether Responses API should store the response server-side.
91 pub store: Option<bool>,
92 /// Optional continuation handle for stateful Responses API turns.
93 pub previous_response_id: Option<String>,
94 /// Optional truncation mode for Responses API requests
95 /// (e.g. "auto", "disabled").
96 pub truncation: Option<String>,
97 /// Optional text verbosity for Responses API requests
98 /// (e.g. "low", "medium", "high").
99 pub text_verbosity: Option<String>,
100}
101
102/// Optional request-time controls for provider calls.
103#[derive(Debug, Clone, Default)]
104pub struct LLMRequestOptions {
105 /// Session identifier used for request-scoped logging correlation.
106 pub session_id: Option<String>,
107 /// Override reasoning effort for this request.
108 pub reasoning_effort: Option<ReasoningEffort>,
109 /// Request provider-side parallel tool call planning when supported.
110 ///
111 /// - OpenAI/Copilot: maps to `parallel_tool_calls`
112 /// - Anthropic: maps to `tool_choice.disable_parallel_tool_use` (inverse)
113 pub parallel_tool_calls: Option<bool>,
114 /// Responses API specific overrides.
115 pub responses: Option<ResponsesRequestOptions>,
116 /// Purpose of this request for observability (e.g., "agent_loop", "task_evaluation").
117 pub request_purpose: Option<String>,
118 /// Provider-agnostic prompt-cache plan describing the stable, cacheable
119 /// prefix of this request. Providers render it in their own dialect
120 /// (Anthropic `cache_control` breakpoints; OpenAI/Gemini rely on the stable
121 /// prefix automatically). `None` means "no explicit cache hints".
122 pub cache: Option<crate::cache::PromptCachePlan>,
123}
124
125/// Trait for LLM provider implementations
126///
127/// This trait defines the interface that all LLM providers must implement
128/// to work with Bamboo's agent system. Providers handle communication with
129/// specific LLM services (OpenAI, Anthropic, local models, etc.).
130///
131/// # Design Principle
132///
133/// The `model` parameter is **required** in `chat_stream`, not optional.
134/// This ensures that the calling code explicitly specifies which model to use,
135/// preventing accidental use of unintended models and making model selection
136/// explicit and auditable.
137///
138/// # Example
139///
140/// ```ignore
141/// use bamboo_agent::agent::llm::provider::LLMProvider;
142///
143/// async fn use_provider(provider: &dyn LLMProvider) {
144/// let stream = provider.chat_stream(
145/// &messages,
146/// &tools,
147/// Some(4096),
148/// "claude-sonnet-4-6", // Model is required
149/// ).await?;
150/// }
151/// ```
152#[async_trait]
153pub trait LLMProvider: Send + Sync {
154 /// Stream chat completion from the LLM
155 ///
156 /// This is the primary method for interacting with LLMs, returning
157 /// a stream of response chunks that can be processed incrementally.
158 ///
159 /// # Arguments
160 ///
161 /// * `messages` - Conversation history and current prompt
162 /// * `tools` - Available tools the LLM can call
163 /// * `max_output_tokens` - Optional limit on response length
164 /// * `model` - **Required** model identifier (e.g., "claude-sonnet-4-6")
165 ///
166 /// # Returns
167 ///
168 /// A stream of `LLMChunk` items containing partial responses
169 ///
170 /// # Errors
171 ///
172 /// Returns `LLMError` on network failures, API errors, or invalid requests
173 async fn chat_stream(
174 &self,
175 messages: &[Message],
176 tools: &[ToolSchema],
177 max_output_tokens: Option<u32>,
178 model: &str,
179 ) -> Result<LLMStream>;
180
181 /// Stream chat completion with optional request-level controls.
182 ///
183 /// Default implementation preserves backward compatibility by delegating to
184 /// [`LLMProvider::chat_stream`].
185 async fn chat_stream_with_options(
186 &self,
187 messages: &[Message],
188 tools: &[ToolSchema],
189 max_output_tokens: Option<u32>,
190 model: &str,
191 _options: Option<&LLMRequestOptions>,
192 ) -> Result<LLMStream> {
193 self.chat_stream(messages, tools, max_output_tokens, model)
194 .await
195 }
196
197 /// Stream from the canonical [`PromptIR`] — the single, rich, provider-agnostic
198 /// request the engine emits once per round.
199 ///
200 /// A provider renders the IR into its own wire format by calling the lowering
201 /// methods ([`PromptIR::system_field`], [`PromptIR::body_chat`],
202 /// [`PromptIR::responses_input`], [`PromptIR::continuation_delta`]). The IR
203 /// carries the stateful Responses continuation, so an adapter derives the
204 /// delta itself rather than the engine pre-baking it.
205 ///
206 /// The default implementation lowers the IR for BOTH wire families and
207 /// delegates to [`chat_stream_with_options`](Self::chat_stream_with_options):
208 /// - the flat message list (`continuation_delta` mid-tool-loop, else `flatten`)
209 /// for the Chat-Completions path;
210 /// - the Responses-API view (`instructions` / `input_messages` /
211 /// `previous_response_id`) derived via [`PromptIR::responses_request_options`]
212 /// and merged onto the request POLICY, so a Responses provider works WITHOUT
213 /// overriding this method (Chat-Completions providers ignore those options).
214 ///
215 /// This is byte-identical to the pre-IR request. Block-native providers (e.g.
216 /// Anthropic) still override this to consume `system_blocks` structurally.
217 async fn chat_stream_ir(
218 &self,
219 ir: &PromptIR,
220 tools: &[ToolSchema],
221 max_output_tokens: Option<u32>,
222 model: &str,
223 options: Option<&LLMRequestOptions>,
224 ) -> Result<LLMStream> {
225 let messages = if ir.continuation.is_some() {
226 ir.continuation_delta()
227 } else {
228 ir.flatten()
229 };
230 let mut effective_options = options.cloned().unwrap_or_default();
231 effective_options.responses =
232 Some(ir.responses_request_options(effective_options.responses.as_ref()));
233 self.chat_stream_with_options(
234 &messages,
235 tools,
236 max_output_tokens,
237 model,
238 Some(&effective_options),
239 )
240 .await
241 }
242
243 /// Lists available models from this provider
244 ///
245 /// Returns a list of model identifiers that can be used with `chat_stream`.
246 /// Default implementation returns an empty list.
247 async fn list_models(&self) -> Result<Vec<String>> {
248 // Default implementation returns empty list
249 Ok(vec![])
250 }
251
252 /// Lists available models with optional token limit metadata.
253 ///
254 /// Default implementation preserves backward compatibility by adapting
255 /// `list_models()` output into metadata entries without limits.
256 async fn list_model_info(&self) -> Result<Vec<ProviderModelInfo>> {
257 Ok(self
258 .list_models()
259 .await?
260 .into_iter()
261 .map(ProviderModelInfo::from_id)
262 .collect())
263 }
264}
265
#[cfg(test)]
mod tests {
    use std::sync::{Arc, Mutex};

    use async_trait::async_trait;
    use futures::{stream, StreamExt};

    use super::*;

    #[tokio::test]
    async fn chat_stream_ir_default_flattens_and_delegates() {
        use crate::prompt_ir::{PromptIR, Segment, SegmentRole};

        // A provider that captures the message list AND the options it is handed.
        #[derive(Default)]
        struct Capture {
            seen: Arc<Mutex<Vec<Message>>>,
            seen_responses: Arc<Mutex<Option<ResponsesRequestOptions>>>,
        }
        #[async_trait]
        impl LLMProvider for Capture {
            async fn chat_stream(
                &self,
                _m: &[Message],
                _t: &[ToolSchema],
                _mt: Option<u32>,
                _model: &str,
            ) -> Result<LLMStream> {
                unreachable!("default chat_stream_ir must route via chat_stream_with_options")
            }
            async fn chat_stream_with_options(
                &self,
                messages: &[Message],
                _t: &[ToolSchema],
                _mt: Option<u32>,
                _model: &str,
                o: Option<&LLMRequestOptions>,
            ) -> Result<LLMStream> {
                *self.seen.lock().expect("seen lock") = messages.to_vec();
                *self.seen_responses.lock().expect("resp lock") =
                    o.and_then(|value| value.responses.clone());
                Ok(Box::pin(stream::iter(Vec::<Result<LLMChunk>>::new())))
            }
        }

        let cap = Capture::default();
        let ir = PromptIR {
            system_text: "sys".into(),
            segments: vec![
                Segment::new(SegmentRole::StablePrefix, vec![Message::user("guide")]),
                Segment::new(SegmentRole::DynamicContext, vec![Message::user("dyn")]),
                Segment::new(SegmentRole::Conversation, vec![Message::user("ask")]),
            ],
            ..PromptIR::default()
        };
        let _ = cap
            .chat_stream_ir(&ir, &[], None, "m", None)
            .await
            .expect("ir stream");

        let seen = cap.seen.lock().expect("seen lock").clone();
        let expected = ir.flatten();
        assert_eq!(seen.len(), expected.len(), "delegates the flattened IR");
        for (got, want) in seen.iter().zip(expected.iter()) {
            assert_eq!(got.role, want.role);
            assert_eq!(got.content, want.content);
        }
        // system + guide + dyn + ask
        assert_eq!(seen.len(), 4);
        assert!(matches!(seen[0].role, bamboo_domain::Role::System));

        // SAFETY NET: the default also derives the Responses-API view from the IR, so
        // a Responses provider works without overriding `chat_stream_ir`. instructions
        // = the (trimmed) system field; input_messages = the full responses_input view
        // (system lifted out, so it does not lead with a system message).
        let responses = cap
            .seen_responses
            .lock()
            .expect("resp lock")
            .clone()
            .expect("default derives Responses options from the IR");
        assert_eq!(responses.instructions.as_deref(), Some("sys"));
        let input = responses.input_messages.expect("input_messages derived");
        assert_eq!(
            input.iter().map(|m| m.content.clone()).collect::<Vec<_>>(),
            vec!["guide".to_string(), "dyn".to_string(), "ask".to_string()],
            "input_messages is the responses_input view: NO leading system message"
        );
    }

    #[derive(Clone, Default)]
    struct RecordingProvider {
        requested_models: Arc<Mutex<Vec<String>>>,
        requested_max_tokens: Arc<Mutex<Vec<Option<u32>>>>,
    }

    #[async_trait]
    impl LLMProvider for RecordingProvider {
        async fn chat_stream(
            &self,
            _messages: &[Message],
            _tools: &[ToolSchema],
            max_output_tokens: Option<u32>,
            model: &str,
        ) -> Result<LLMStream> {
            if let Ok(mut models) = self.requested_models.lock() {
                models.push(model.to_string());
            }
            if let Ok(mut max_tokens) = self.requested_max_tokens.lock() {
                max_tokens.push(max_output_tokens);
            }

            Ok(Box::pin(stream::empty()))
        }
    }

    /// Provider that only overrides `list_models`, so `list_model_info`
    /// exercises the trait's default adapter.
    struct ListingProvider;

    #[async_trait]
    impl LLMProvider for ListingProvider {
        async fn chat_stream(
            &self,
            _messages: &[Message],
            _tools: &[ToolSchema],
            _max_output_tokens: Option<u32>,
            _model: &str,
        ) -> Result<LLMStream> {
            Ok(Box::pin(stream::empty()))
        }

        async fn list_models(&self) -> Result<Vec<String>> {
            Ok(vec!["model-a".to_string(), "model-b".to_string()])
        }
    }

    #[tokio::test]
    async fn chat_stream_with_options_delegates_to_chat_stream_with_same_model_and_tokens() {
        let provider = RecordingProvider::default();
        let options = LLMRequestOptions::default();

        let mut stream = provider
            .chat_stream_with_options(&[], &[], Some(512), "gpt-test", Some(&options))
            .await
            .expect("delegation should succeed");
        assert!(stream.next().await.is_none());

        assert_eq!(
            provider
                .requested_models
                .lock()
                .expect("lock poisoned")
                .as_slice(),
            ["gpt-test"]
        );
        assert_eq!(
            provider
                .requested_max_tokens
                .lock()
                .expect("lock poisoned")
                .as_slice(),
            [Some(512)]
        );
    }

    #[tokio::test]
    async fn list_models_returns_empty_by_default() {
        let provider = RecordingProvider::default();
        let models = provider
            .list_models()
            .await
            .expect("default list_models should succeed");
        assert!(models.is_empty());
    }

    #[tokio::test]
    async fn list_model_info_default_is_empty_when_list_models_is_empty() {
        let provider = RecordingProvider::default();
        let infos = provider
            .list_model_info()
            .await
            .expect("default list_model_info should succeed");
        assert!(infos.is_empty());
    }

    #[tokio::test]
    async fn list_model_info_default_adapts_list_models_without_limits() {
        let infos = ListingProvider
            .list_model_info()
            .await
            .expect("default list_model_info should succeed");

        let ids: Vec<&str> = infos.iter().map(|info| info.id.as_str()).collect();
        assert_eq!(ids, ["model-a", "model-b"]);
        assert!(infos
            .iter()
            .all(|info| info.max_context_tokens.is_none() && info.max_output_tokens.is_none()));
    }

    #[test]
    fn provider_model_info_from_id_has_no_limits() {
        let info = ProviderModelInfo::from_id("gpt-test");
        assert_eq!(info.id, "gpt-test");
        assert_eq!(info.max_context_tokens, None);
        assert_eq!(info.max_output_tokens, None);
    }

    #[test]
    fn request_options_default_has_no_purpose() {
        let opts = LLMRequestOptions::default();
        assert!(opts.request_purpose.is_none());
    }

    #[test]
    fn request_options_purpose_is_set_and_readable() {
        let opts = LLMRequestOptions {
            request_purpose: Some("title_generation".to_string()),
            ..Default::default()
        };
        assert_eq!(opts.request_purpose.as_deref(), Some("title_generation"));
    }
}