bamboo_llm/provider.rs
1//! LLM provider trait and types
2//!
3//! This module defines the interface for LLM (Large Language Model) providers,
4//! enabling support for multiple LLM backends through a common trait.
5
6use crate::types::LLMChunk;
7use async_trait::async_trait;
8use bamboo_domain::Message;
9use bamboo_domain::ReasoningEffort;
10use bamboo_domain::ToolSchema;
11use futures::Stream;
12use std::pin::Pin;
13use thiserror::Error;
14
15/// Errors that can occur when working with LLM providers
16#[derive(Error, Debug)]
17pub enum LLMError {
18 /// HTTP request/response errors
19 #[error("HTTP error: {0}")]
20 Http(#[from] reqwest::Error),
21
22 /// JSON serialization/deserialization errors
23 #[error("JSON error: {0}")]
24 Json(#[from] serde_json::Error),
25
26 /// Streaming response errors
27 #[error("Stream error: {0}")]
28 Stream(String),
29
30 /// LLM API errors (rate limits, invalid requests, etc.)
31 #[error("API error: {0}")]
32 Api(String),
33
34 /// Authentication/authorization errors
35 #[error("Authentication error: {0}")]
36 Auth(String),
37
38 /// Protocol conversion errors
39 #[error("Protocol conversion error: {0}")]
40 Protocol(#[from] crate::protocol::ProtocolError),
41}
42
43/// Convenient result type for LLM operations
44pub type Result<T> = std::result::Result<T, LLMError>;
45
46/// Type alias for boxed streaming LLM responses
47pub type LLMStream = Pin<Box<dyn Stream<Item = Result<LLMChunk>> + Send>>;
48
/// Description of a single model, as reported by
/// [`LLMProvider::list_model_info`].
///
/// Token limits are optional because they are not always known.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ProviderModelInfo {
    /// Model identifier (the value passed as `model` to the chat methods).
    pub id: String,
    /// Total context window (input + output) in tokens, if known.
    pub max_context_tokens: Option<u32>,
    /// Maximum number of output/completion tokens, if known.
    pub max_output_tokens: Option<u32>,
}

impl ProviderModelInfo {
    /// Build metadata that carries only the model id.
    ///
    /// Both token limits are left as `None`.
    pub fn from_id(id: impl Into<String>) -> Self {
        let id: String = id.into();
        let (max_context_tokens, max_output_tokens) = (None, None);
        Self {
            id,
            max_context_tokens,
            max_output_tokens,
        }
    }
}
70
71/// Optional request-time controls for provider calls.
72#[derive(Debug, Clone, Default)]
73pub struct ResponsesRequestOptions {
74 /// Optional top-level instructions for Responses API requests.
75 pub instructions: Option<String>,
76 /// Optional message list to serialize into the Responses API `input` array.
77 ///
78 /// When omitted, providers fall back to the generic `messages` slice passed
79 /// to `chat_stream_with_options`. This lets the engine provide a
80 /// Responses-specific input view (for example, without a duplicated stable
81 /// system message) while preserving backward compatibility for non-Responses
82 /// callers and providers.
83 pub input_messages: Option<Vec<Message>>,
84 /// Optional reasoning summary control for Responses API requests
85 /// (e.g. "auto", "concise", "detailed").
86 pub reasoning_summary: Option<String>,
87 /// Optional include list for Responses API requests.
88 pub include: Option<Vec<String>>,
89 /// Whether Responses API should store the response server-side.
90 pub store: Option<bool>,
91 /// Optional continuation handle for stateful Responses API turns.
92 pub previous_response_id: Option<String>,
93 /// Optional truncation mode for Responses API requests
94 /// (e.g. "auto", "disabled").
95 pub truncation: Option<String>,
96 /// Optional text verbosity for Responses API requests
97 /// (e.g. "low", "medium", "high").
98 pub text_verbosity: Option<String>,
99}
100
101/// Optional request-time controls for provider calls.
102#[derive(Debug, Clone, Default)]
103pub struct LLMRequestOptions {
104 /// Session identifier used for request-scoped logging correlation.
105 pub session_id: Option<String>,
106 /// Override reasoning effort for this request.
107 pub reasoning_effort: Option<ReasoningEffort>,
108 /// Request provider-side parallel tool call planning when supported.
109 ///
110 /// - OpenAI/Copilot: maps to `parallel_tool_calls`
111 /// - Anthropic: maps to `tool_choice.disable_parallel_tool_use` (inverse)
112 pub parallel_tool_calls: Option<bool>,
113 /// Responses API specific overrides.
114 pub responses: Option<ResponsesRequestOptions>,
115 /// Purpose of this request for observability (e.g., "agent_loop", "task_evaluation").
116 pub request_purpose: Option<String>,
117 /// Provider-agnostic prompt-cache plan describing the stable, cacheable
118 /// prefix of this request. Providers render it in their own dialect
119 /// (Anthropic `cache_control` breakpoints; OpenAI/Gemini rely on the stable
120 /// prefix automatically). `None` means "no explicit cache hints".
121 pub cache: Option<crate::cache::PromptCachePlan>,
122}
123
124/// Canonical, provider-facing prompt structure: the engine assembles these four
125/// layers ONCE, and each provider adapter renders them into its own wire format
126/// (system field + message array + cache breakpoints) instead of re-deriving the
127/// structure from a pre-flattened message list. This is what lets every provider
128/// be a pure adapter — the prompt-assembly logic lives in Bamboo, not duplicated
129/// across providers.
130///
131/// Concatenation order is fixed and defines the message layout:
132/// `[system(stable_instructions)] + stable_prefix_messages + dynamic_context_messages + conversation_messages`.
133///
134/// The lane boundaries are also the natural cache breakpoints: everything up to
135/// (and including) `stable_prefix_messages` is the stable, cacheable prefix;
136/// `dynamic_context_messages` onward changes per round.
137#[derive(Debug, Clone, Default)]
138pub struct PromptLanes {
139 /// Static system instructions — the cacheable base. Rendered into the
140 /// provider's dedicated system field, NOT the message array.
141 pub stable_instructions: String,
142 /// Session-stable context messages (tool guide, connected MCP servers'
143 /// guidance, workspace, env, skills): fixed positions that change rarely. The
144 /// stable cache prefix ends after these.
145 pub stable_prefix_messages: Vec<Message>,
146 /// Per-round dynamic context (task snapshot, recalled memory, conversation
147 /// summary): changes turn to turn, so it sits AFTER the cache breakpoint.
148 pub dynamic_context_messages: Vec<Message>,
149 /// The actual user / assistant / tool conversation history.
150 pub conversation_messages: Vec<Message>,
151}
152
153impl PromptLanes {
154 /// Flatten the lanes into one message list in canonical order — the exact
155 /// shape a provider that has NOT yet been migrated to consume lanes still
156 /// expects, so the default trait path stays byte-identical to today.
157 pub fn flatten(&self) -> Vec<Message> {
158 let mut messages = Vec::with_capacity(
159 1 + self.stable_prefix_messages.len()
160 + self.dynamic_context_messages.len()
161 + self.conversation_messages.len(),
162 );
163 if !self.stable_instructions.trim().is_empty() {
164 messages.push(Message::system(self.stable_instructions.trim().to_string()));
165 }
166 messages.extend(self.stable_prefix_messages.iter().cloned());
167 messages.extend(self.dynamic_context_messages.iter().cloned());
168 messages.extend(self.conversation_messages.iter().cloned());
169 messages
170 }
171}
172
173/// Trait for LLM provider implementations
174///
175/// This trait defines the interface that all LLM providers must implement
176/// to work with Bamboo's agent system. Providers handle communication with
177/// specific LLM services (OpenAI, Anthropic, local models, etc.).
178///
179/// # Design Principle
180///
181/// The `model` parameter is **required** in `chat_stream`, not optional.
182/// This ensures that the calling code explicitly specifies which model to use,
183/// preventing accidental use of unintended models and making model selection
184/// explicit and auditable.
185///
186/// # Example
187///
188/// ```ignore
189/// use bamboo_agent::agent::llm::provider::LLMProvider;
190///
191/// async fn use_provider(provider: &dyn LLMProvider) {
192/// let stream = provider.chat_stream(
193/// &messages,
194/// &tools,
195/// Some(4096),
196/// "claude-sonnet-4-6", // Model is required
197/// ).await?;
198/// }
199/// ```
200#[async_trait]
201pub trait LLMProvider: Send + Sync {
202 /// Stream chat completion from the LLM
203 ///
204 /// This is the primary method for interacting with LLMs, returning
205 /// a stream of response chunks that can be processed incrementally.
206 ///
207 /// # Arguments
208 ///
209 /// * `messages` - Conversation history and current prompt
210 /// * `tools` - Available tools the LLM can call
211 /// * `max_output_tokens` - Optional limit on response length
212 /// * `model` - **Required** model identifier (e.g., "claude-sonnet-4-6")
213 ///
214 /// # Returns
215 ///
216 /// A stream of `LLMChunk` items containing partial responses
217 ///
218 /// # Errors
219 ///
220 /// Returns `LLMError` on network failures, API errors, or invalid requests
221 async fn chat_stream(
222 &self,
223 messages: &[Message],
224 tools: &[ToolSchema],
225 max_output_tokens: Option<u32>,
226 model: &str,
227 ) -> Result<LLMStream>;
228
229 /// Stream chat completion with optional request-level controls.
230 ///
231 /// Default implementation preserves backward compatibility by delegating to
232 /// [`LLMProvider::chat_stream`].
233 async fn chat_stream_with_options(
234 &self,
235 messages: &[Message],
236 tools: &[ToolSchema],
237 max_output_tokens: Option<u32>,
238 model: &str,
239 _options: Option<&LLMRequestOptions>,
240 ) -> Result<LLMStream> {
241 self.chat_stream(messages, tools, max_output_tokens, model)
242 .await
243 }
244
245 /// Stream a completion from the canonical [`PromptLanes`] contract — the
246 /// structure-preserving entry point.
247 ///
248 /// The provider receives the prompt LAYERS (static system, stable prefix,
249 /// dynamic context, conversation) and is expected to render them into its own
250 /// dialect: place the system block in its system field and the cache
251 /// breakpoint at the structural stable↔dynamic boundary, rather than
252 /// re-deriving both from a flattened message list.
253 ///
254 /// The default implementation flattens the lanes ([`PromptLanes::flatten`])
255 /// and delegates to [`LLMProvider::chat_stream_with_options`], so a provider
256 /// that has not yet been migrated produces exactly the request it does today.
257 async fn chat_stream_lanes(
258 &self,
259 lanes: &PromptLanes,
260 tools: &[ToolSchema],
261 max_output_tokens: Option<u32>,
262 model: &str,
263 options: Option<&LLMRequestOptions>,
264 ) -> Result<LLMStream> {
265 let messages = lanes.flatten();
266 self.chat_stream_with_options(&messages, tools, max_output_tokens, model, options)
267 .await
268 }
269
270 /// Lists available models from this provider
271 ///
272 /// Returns a list of model identifiers that can be used with `chat_stream`.
273 /// Default implementation returns an empty list.
274 async fn list_models(&self) -> Result<Vec<String>> {
275 // Default implementation returns empty list
276 Ok(vec![])
277 }
278
279 /// Lists available models with optional token limit metadata.
280 ///
281 /// Default implementation preserves backward compatibility by adapting
282 /// `list_models()` output into metadata entries without limits.
283 async fn list_model_info(&self) -> Result<Vec<ProviderModelInfo>> {
284 Ok(self
285 .list_models()
286 .await?
287 .into_iter()
288 .map(ProviderModelInfo::from_id)
289 .collect())
290 }
291}
292
#[cfg(test)]
mod tests {
    //! Tests for the `PromptLanes` flattening contract and for the
    //! backward-compatible default methods of `LLMProvider`.

    use std::sync::{Arc, Mutex};

    use async_trait::async_trait;
    use futures::{stream, StreamExt};

    use super::*;

    #[test]
    fn prompt_lanes_flatten_preserves_canonical_order() {
        let lanes = PromptLanes {
            stable_instructions: " base system ".to_string(),
            stable_prefix_messages: vec![Message::user("tool-guide")],
            dynamic_context_messages: vec![Message::user("task-snapshot")],
            conversation_messages: vec![Message::user("real ask")],
        };
        let flat = lanes.flatten();
        assert_eq!(flat.len(), 4);
        assert!(matches!(flat[0].role, bamboo_domain::Role::System));
        assert_eq!(flat[0].content, "base system"); // trimmed
        assert_eq!(flat[1].content, "tool-guide");
        assert_eq!(flat[2].content, "task-snapshot");
        assert_eq!(flat[3].content, "real ask");
    }

    #[tokio::test]
    async fn chat_stream_lanes_default_flattens_and_delegates() {
        // A provider that captures whatever message list it is handed.
        #[derive(Default)]
        struct Capture {
            seen: Arc<Mutex<Vec<Message>>>,
        }
        #[async_trait]
        impl LLMProvider for Capture {
            async fn chat_stream(
                &self,
                _m: &[Message],
                _t: &[ToolSchema],
                _mt: Option<u32>,
                _model: &str,
            ) -> Result<LLMStream> {
                unreachable!("default chat_stream_lanes must route via chat_stream_with_options")
            }
            async fn chat_stream_with_options(
                &self,
                messages: &[Message],
                _t: &[ToolSchema],
                _mt: Option<u32>,
                _model: &str,
                _o: Option<&LLMRequestOptions>,
            ) -> Result<LLMStream> {
                *self.seen.lock().expect("seen lock") = messages.to_vec();
                Ok(Box::pin(stream::iter(Vec::<Result<LLMChunk>>::new())))
            }
        }

        let cap = Capture::default();
        let lanes = PromptLanes {
            stable_instructions: "sys".into(),
            stable_prefix_messages: vec![Message::user("guide")],
            dynamic_context_messages: vec![Message::user("dyn")],
            conversation_messages: vec![Message::user("ask")],
        };
        let _ = cap
            .chat_stream_lanes(&lanes, &[], None, "m", None)
            .await
            .expect("lanes stream");

        let seen = cap.seen.lock().expect("seen lock").clone();
        let expected = lanes.flatten();
        assert_eq!(seen.len(), expected.len(), "delegates the flattened lanes");
        for (got, want) in seen.iter().zip(expected.iter()) {
            assert_eq!(got.role, want.role);
            assert_eq!(got.content, want.content);
        }
        // system + guide + dyn + ask
        assert_eq!(seen.len(), 4);
        assert!(matches!(seen[0].role, bamboo_domain::Role::System));
    }

    #[test]
    fn prompt_lanes_flatten_omits_empty_system() {
        let lanes = PromptLanes {
            stable_instructions: " ".to_string(),
            conversation_messages: vec![Message::user("hi")],
            ..PromptLanes::default()
        };
        let flat = lanes.flatten();
        assert_eq!(flat.len(), 1);
        assert!(matches!(flat[0].role, bamboo_domain::Role::User));
    }

    #[test]
    fn prompt_lanes_flatten_of_default_is_empty() {
        // Blank instructions and empty lanes must not yield a stray system message.
        assert!(PromptLanes::default().flatten().is_empty());
    }

    /// Records the model and token limit of every `chat_stream` call so tests
    /// can check what the default methods delegate.
    #[derive(Clone, Default)]
    struct RecordingProvider {
        requested_models: Arc<Mutex<Vec<String>>>,
        requested_max_tokens: Arc<Mutex<Vec<Option<u32>>>>,
    }

    #[async_trait]
    impl LLMProvider for RecordingProvider {
        async fn chat_stream(
            &self,
            _messages: &[Message],
            _tools: &[ToolSchema],
            max_output_tokens: Option<u32>,
            model: &str,
        ) -> Result<LLMStream> {
            if let Ok(mut models) = self.requested_models.lock() {
                models.push(model.to_string());
            }
            if let Ok(mut max_tokens) = self.requested_max_tokens.lock() {
                max_tokens.push(max_output_tokens);
            }

            Ok(Box::pin(stream::empty()))
        }
    }

    /// Overrides only `list_models`, to exercise the default `list_model_info`
    /// adapter.
    struct ListingProvider;

    #[async_trait]
    impl LLMProvider for ListingProvider {
        async fn chat_stream(
            &self,
            _messages: &[Message],
            _tools: &[ToolSchema],
            _max_output_tokens: Option<u32>,
            _model: &str,
        ) -> Result<LLMStream> {
            Ok(Box::pin(stream::empty()))
        }

        async fn list_models(&self) -> Result<Vec<String>> {
            Ok(vec!["model-a".to_string(), "model-b".to_string()])
        }
    }

    #[tokio::test]
    async fn chat_stream_with_options_delegates_to_chat_stream_with_same_model_and_tokens() {
        let provider = RecordingProvider::default();
        let options = LLMRequestOptions::default();

        let mut stream = provider
            .chat_stream_with_options(&[], &[], Some(512), "gpt-test", Some(&options))
            .await
            .expect("delegation should succeed");
        assert!(stream.next().await.is_none());

        assert_eq!(
            provider
                .requested_models
                .lock()
                .expect("lock poisoned")
                .as_slice(),
            ["gpt-test"]
        );
        assert_eq!(
            provider
                .requested_max_tokens
                .lock()
                .expect("lock poisoned")
                .as_slice(),
            [Some(512)]
        );
    }

    #[tokio::test]
    async fn list_models_returns_empty_by_default() {
        let provider = RecordingProvider::default();
        let models = provider
            .list_models()
            .await
            .expect("default list_models should succeed");
        assert!(models.is_empty());
    }

    #[tokio::test]
    async fn list_model_info_is_empty_when_list_models_is_empty() {
        let provider = RecordingProvider::default();
        let infos = provider
            .list_model_info()
            .await
            .expect("default list_model_info should succeed");
        assert!(infos.is_empty());
    }

    #[tokio::test]
    async fn list_model_info_default_adapts_list_models_without_limits() {
        let provider = ListingProvider;
        let infos = provider
            .list_model_info()
            .await
            .expect("default list_model_info should succeed");
        // Order is preserved and every entry carries the id only.
        assert_eq!(
            infos,
            vec![
                ProviderModelInfo::from_id("model-a"),
                ProviderModelInfo::from_id("model-b"),
            ]
        );
        assert!(infos
            .iter()
            .all(|info| info.max_context_tokens.is_none() && info.max_output_tokens.is_none()));
    }

    #[test]
    fn request_options_default_has_no_purpose() {
        let opts = LLMRequestOptions::default();
        assert!(opts.request_purpose.is_none());
    }

    #[test]
    fn request_options_purpose_is_set_and_readable() {
        let opts = LLMRequestOptions {
            request_purpose: Some("title_generation".to_string()),
            ..Default::default()
        };
        assert_eq!(opts.request_purpose.as_deref(), Some("title_generation"));
    }
}
465}