Skip to main content

bamboo_llm/
provider.rs

//! LLM provider trait and types
//!
//! This module defines the interface for LLM (Large Language Model) providers,
//! enabling support for multiple LLM backends through a common trait.
5
6use crate::types::LLMChunk;
7use async_trait::async_trait;
8use bamboo_domain::Message;
9use bamboo_domain::ReasoningEffort;
10use bamboo_domain::ToolSchema;
11use futures::Stream;
12use std::pin::Pin;
13use thiserror::Error;
14
15/// Errors that can occur when working with LLM providers
16#[derive(Error, Debug)]
17pub enum LLMError {
18    /// HTTP request/response errors
19    #[error("HTTP error: {0}")]
20    Http(#[from] reqwest::Error),
21
22    /// JSON serialization/deserialization errors
23    #[error("JSON error: {0}")]
24    Json(#[from] serde_json::Error),
25
26    /// Streaming response errors
27    #[error("Stream error: {0}")]
28    Stream(String),
29
30    /// LLM API errors (rate limits, invalid requests, etc.)
31    #[error("API error: {0}")]
32    Api(String),
33
34    /// Authentication/authorization errors
35    #[error("Authentication error: {0}")]
36    Auth(String),
37
38    /// Protocol conversion errors
39    #[error("Protocol conversion error: {0}")]
40    Protocol(#[from] crate::protocol::ProtocolError),
41}
42
43/// Convenient result type for LLM operations
44pub type Result<T> = std::result::Result<T, LLMError>;
45
46/// Type alias for boxed streaming LLM responses
47pub type LLMStream = Pin<Box<dyn Stream<Item = Result<LLMChunk>> + Send>>;
48
/// Metadata for a provider model returned by `list_model_info`.
///
/// Both token limits are optional; `None` means the limit is not known.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ProviderModelInfo {
    /// Model identifier.
    pub id: String,
    /// Maximum context window (input + output) in tokens when known.
    pub max_context_tokens: Option<u32>,
    /// Maximum output/completion tokens when known.
    pub max_output_tokens: Option<u32>,
}

impl ProviderModelInfo {
    /// Create metadata with only a model id (no token limits).
    ///
    /// Accepts anything convertible into a `String`; both
    /// `max_context_tokens` and `max_output_tokens` are left as `None`.
    pub fn from_id(id: impl Into<String>) -> Self {
        Self {
            id: id.into(),
            max_context_tokens: None,
            max_output_tokens: None,
        }
    }
}
70
71/// Optional request-time controls for provider calls.
72#[derive(Debug, Clone, Default)]
73pub struct ResponsesRequestOptions {
74    /// Optional top-level instructions for Responses API requests.
75    pub instructions: Option<String>,
76    /// Optional message list to serialize into the Responses API `input` array.
77    ///
78    /// When omitted, providers fall back to the generic `messages` slice passed
79    /// to `chat_stream_with_options`. This lets the engine provide a
80    /// Responses-specific input view (for example, without a duplicated stable
81    /// system message) while preserving backward compatibility for non-Responses
82    /// callers and providers.
83    pub input_messages: Option<Vec<Message>>,
84    /// Optional reasoning summary control for Responses API requests
85    /// (e.g. "auto", "concise", "detailed").
86    pub reasoning_summary: Option<String>,
87    /// Optional include list for Responses API requests.
88    pub include: Option<Vec<String>>,
89    /// Whether Responses API should store the response server-side.
90    pub store: Option<bool>,
91    /// Optional continuation handle for stateful Responses API turns.
92    pub previous_response_id: Option<String>,
93    /// Optional truncation mode for Responses API requests
94    /// (e.g. "auto", "disabled").
95    pub truncation: Option<String>,
96    /// Optional text verbosity for Responses API requests
97    /// (e.g. "low", "medium", "high").
98    pub text_verbosity: Option<String>,
99}
100
101/// Optional request-time controls for provider calls.
102#[derive(Debug, Clone, Default)]
103pub struct LLMRequestOptions {
104    /// Session identifier used for request-scoped logging correlation.
105    pub session_id: Option<String>,
106    /// Override reasoning effort for this request.
107    pub reasoning_effort: Option<ReasoningEffort>,
108    /// Request provider-side parallel tool call planning when supported.
109    ///
110    /// - OpenAI/Copilot: maps to `parallel_tool_calls`
111    /// - Anthropic: maps to `tool_choice.disable_parallel_tool_use` (inverse)
112    pub parallel_tool_calls: Option<bool>,
113    /// Responses API specific overrides.
114    pub responses: Option<ResponsesRequestOptions>,
115    /// Purpose of this request for observability (e.g., "agent_loop", "task_evaluation").
116    pub request_purpose: Option<String>,
117    /// Provider-agnostic prompt-cache plan describing the stable, cacheable
118    /// prefix of this request. Providers render it in their own dialect
119    /// (Anthropic `cache_control` breakpoints; OpenAI/Gemini rely on the stable
120    /// prefix automatically). `None` means "no explicit cache hints".
121    pub cache: Option<crate::cache::PromptCachePlan>,
122}
123
124/// Canonical, provider-facing prompt structure: the engine assembles these four
125/// layers ONCE, and each provider adapter renders them into its own wire format
126/// (system field + message array + cache breakpoints) instead of re-deriving the
127/// structure from a pre-flattened message list. This is what lets every provider
128/// be a pure adapter — the prompt-assembly logic lives in Bamboo, not duplicated
129/// across providers.
130///
131/// Concatenation order is fixed and defines the message layout:
132/// `[system(stable_instructions)] + stable_prefix_messages + dynamic_context_messages + conversation_messages`.
133///
134/// The lane boundaries are also the natural cache breakpoints: everything up to
135/// (and including) `stable_prefix_messages` is the stable, cacheable prefix;
136/// `dynamic_context_messages` onward changes per round.
137#[derive(Debug, Clone, Default)]
138pub struct PromptLanes {
139    /// Static system instructions — the cacheable base. Rendered into the
140    /// provider's dedicated system field, NOT the message array.
141    pub stable_instructions: String,
142    /// Session-stable context messages (tool guide, connected MCP servers'
143    /// guidance, workspace, env, skills): fixed positions that change rarely. The
144    /// stable cache prefix ends after these.
145    pub stable_prefix_messages: Vec<Message>,
146    /// Per-round dynamic context (task snapshot, recalled memory, conversation
147    /// summary): changes turn to turn, so it sits AFTER the cache breakpoint.
148    pub dynamic_context_messages: Vec<Message>,
149    /// The actual user / assistant / tool conversation history.
150    pub conversation_messages: Vec<Message>,
151}
152
153impl PromptLanes {
154    /// Flatten the lanes into one message list in canonical order — the exact
155    /// shape a provider that has NOT yet been migrated to consume lanes still
156    /// expects, so the default trait path stays byte-identical to today.
157    pub fn flatten(&self) -> Vec<Message> {
158        let mut messages = Vec::with_capacity(
159            1 + self.stable_prefix_messages.len()
160                + self.dynamic_context_messages.len()
161                + self.conversation_messages.len(),
162        );
163        if !self.stable_instructions.trim().is_empty() {
164            messages.push(Message::system(self.stable_instructions.trim().to_string()));
165        }
166        messages.extend(self.stable_prefix_messages.iter().cloned());
167        messages.extend(self.dynamic_context_messages.iter().cloned());
168        messages.extend(self.conversation_messages.iter().cloned());
169        messages
170    }
171}
172
173/// Trait for LLM provider implementations
174///
175/// This trait defines the interface that all LLM providers must implement
176/// to work with Bamboo's agent system. Providers handle communication with
177/// specific LLM services (OpenAI, Anthropic, local models, etc.).
178///
179/// # Design Principle
180///
181/// The `model` parameter is **required** in `chat_stream`, not optional.
182/// This ensures that the calling code explicitly specifies which model to use,
183/// preventing accidental use of unintended models and making model selection
184/// explicit and auditable.
185///
186/// # Example
187///
188/// ```ignore
189/// use bamboo_agent::agent::llm::provider::LLMProvider;
190///
191/// async fn use_provider(provider: &dyn LLMProvider) {
192///     let stream = provider.chat_stream(
193///         &messages,
194///         &tools,
195///         Some(4096),
196///         "claude-sonnet-4-6", // Model is required
197///     ).await?;
198/// }
199/// ```
200#[async_trait]
201pub trait LLMProvider: Send + Sync {
202    /// Stream chat completion from the LLM
203    ///
204    /// This is the primary method for interacting with LLMs, returning
205    /// a stream of response chunks that can be processed incrementally.
206    ///
207    /// # Arguments
208    ///
209    /// * `messages` - Conversation history and current prompt
210    /// * `tools` - Available tools the LLM can call
211    /// * `max_output_tokens` - Optional limit on response length
212    /// * `model` - **Required** model identifier (e.g., "claude-sonnet-4-6")
213    ///
214    /// # Returns
215    ///
216    /// A stream of `LLMChunk` items containing partial responses
217    ///
218    /// # Errors
219    ///
220    /// Returns `LLMError` on network failures, API errors, or invalid requests
221    async fn chat_stream(
222        &self,
223        messages: &[Message],
224        tools: &[ToolSchema],
225        max_output_tokens: Option<u32>,
226        model: &str,
227    ) -> Result<LLMStream>;
228
229    /// Stream chat completion with optional request-level controls.
230    ///
231    /// Default implementation preserves backward compatibility by delegating to
232    /// [`LLMProvider::chat_stream`].
233    async fn chat_stream_with_options(
234        &self,
235        messages: &[Message],
236        tools: &[ToolSchema],
237        max_output_tokens: Option<u32>,
238        model: &str,
239        _options: Option<&LLMRequestOptions>,
240    ) -> Result<LLMStream> {
241        self.chat_stream(messages, tools, max_output_tokens, model)
242            .await
243    }
244
245    /// Stream a completion from the canonical [`PromptLanes`] contract — the
246    /// structure-preserving entry point.
247    ///
248    /// The provider receives the prompt LAYERS (static system, stable prefix,
249    /// dynamic context, conversation) and is expected to render them into its own
250    /// dialect: place the system block in its system field and the cache
251    /// breakpoint at the structural stable↔dynamic boundary, rather than
252    /// re-deriving both from a flattened message list.
253    ///
254    /// The default implementation flattens the lanes ([`PromptLanes::flatten`])
255    /// and delegates to [`LLMProvider::chat_stream_with_options`], so a provider
256    /// that has not yet been migrated produces exactly the request it does today.
257    async fn chat_stream_lanes(
258        &self,
259        lanes: &PromptLanes,
260        tools: &[ToolSchema],
261        max_output_tokens: Option<u32>,
262        model: &str,
263        options: Option<&LLMRequestOptions>,
264    ) -> Result<LLMStream> {
265        let messages = lanes.flatten();
266        self.chat_stream_with_options(&messages, tools, max_output_tokens, model, options)
267            .await
268    }
269
270    /// Lists available models from this provider
271    ///
272    /// Returns a list of model identifiers that can be used with `chat_stream`.
273    /// Default implementation returns an empty list.
274    async fn list_models(&self) -> Result<Vec<String>> {
275        // Default implementation returns empty list
276        Ok(vec![])
277    }
278
279    /// Lists available models with optional token limit metadata.
280    ///
281    /// Default implementation preserves backward compatibility by adapting
282    /// `list_models()` output into metadata entries without limits.
283    async fn list_model_info(&self) -> Result<Vec<ProviderModelInfo>> {
284        Ok(self
285            .list_models()
286            .await?
287            .into_iter()
288            .map(ProviderModelInfo::from_id)
289            .collect())
290    }
291}
292
#[cfg(test)]
mod tests {
    //! Unit tests for `PromptLanes::flatten`, the default `LLMProvider`
    //! method implementations, and `LLMRequestOptions` defaults.

    use std::sync::{Arc, Mutex};

    use async_trait::async_trait;
    use futures::{stream, StreamExt};

    use super::*;

    #[test]
    fn prompt_lanes_flatten_preserves_canonical_order() {
        let lanes = PromptLanes {
            stable_instructions: "  base system  ".to_string(),
            stable_prefix_messages: vec![Message::user("tool-guide")],
            dynamic_context_messages: vec![Message::user("task-snapshot")],
            conversation_messages: vec![Message::user("real ask")],
        };
        let flat = lanes.flatten();
        assert_eq!(flat.len(), 4);
        assert!(matches!(flat[0].role, bamboo_domain::Role::System));
        assert_eq!(flat[0].content, "base system"); // trimmed
        assert_eq!(flat[1].content, "tool-guide");
        assert_eq!(flat[2].content, "task-snapshot");
        assert_eq!(flat[3].content, "real ask");
    }

    #[tokio::test]
    async fn chat_stream_lanes_default_flattens_and_delegates() {
        // A provider that captures whatever message list it is handed.
        #[derive(Default)]
        struct Capture {
            seen: Arc<Mutex<Vec<Message>>>,
        }
        #[async_trait]
        impl LLMProvider for Capture {
            async fn chat_stream(
                &self,
                _m: &[Message],
                _t: &[ToolSchema],
                _mt: Option<u32>,
                _model: &str,
            ) -> Result<LLMStream> {
                unreachable!("default chat_stream_lanes must route via chat_stream_with_options")
            }
            async fn chat_stream_with_options(
                &self,
                messages: &[Message],
                _t: &[ToolSchema],
                _mt: Option<u32>,
                _model: &str,
                _o: Option<&LLMRequestOptions>,
            ) -> Result<LLMStream> {
                *self.seen.lock().expect("seen lock") = messages.to_vec();
                Ok(Box::pin(stream::iter(Vec::<Result<LLMChunk>>::new())))
            }
        }

        let cap = Capture::default();
        let lanes = PromptLanes {
            stable_instructions: "sys".into(),
            stable_prefix_messages: vec![Message::user("guide")],
            dynamic_context_messages: vec![Message::user("dyn")],
            conversation_messages: vec![Message::user("ask")],
        };
        let _ = cap
            .chat_stream_lanes(&lanes, &[], None, "m", None)
            .await
            .expect("lanes stream");

        let seen = cap.seen.lock().expect("seen lock").clone();
        let expected = lanes.flatten();
        assert_eq!(seen.len(), expected.len(), "delegates the flattened lanes");
        for (got, want) in seen.iter().zip(expected.iter()) {
            assert_eq!(got.role, want.role);
            assert_eq!(got.content, want.content);
        }
        // system + guide + dyn + ask
        assert_eq!(seen.len(), 4);
        assert!(matches!(seen[0].role, bamboo_domain::Role::System));
    }

    #[test]
    fn prompt_lanes_flatten_omits_empty_system() {
        let lanes = PromptLanes {
            stable_instructions: "   ".to_string(),
            conversation_messages: vec![Message::user("hi")],
            ..PromptLanes::default()
        };
        let flat = lanes.flatten();
        assert_eq!(flat.len(), 1);
        assert!(matches!(flat[0].role, bamboo_domain::Role::User));
    }

    /// Provider that records the `model` and `max_output_tokens` of every
    /// `chat_stream` call and returns an empty stream.
    #[derive(Clone, Default)]
    struct RecordingProvider {
        requested_models: Arc<Mutex<Vec<String>>>,
        requested_max_tokens: Arc<Mutex<Vec<Option<u32>>>>,
    }

    #[async_trait]
    impl LLMProvider for RecordingProvider {
        async fn chat_stream(
            &self,
            _messages: &[Message],
            _tools: &[ToolSchema],
            max_output_tokens: Option<u32>,
            model: &str,
        ) -> Result<LLMStream> {
            // Fail loudly on a poisoned lock instead of silently dropping the
            // record, which would surface later as a confusing assertion failure.
            self.requested_models
                .lock()
                .expect("lock poisoned")
                .push(model.to_string());
            self.requested_max_tokens
                .lock()
                .expect("lock poisoned")
                .push(max_output_tokens);

            Ok(Box::pin(stream::empty()))
        }
    }

    #[tokio::test]
    async fn chat_stream_with_options_delegates_to_chat_stream_with_same_model_and_tokens() {
        let provider = RecordingProvider::default();
        let options = LLMRequestOptions::default();

        let mut stream = provider
            .chat_stream_with_options(&[], &[], Some(512), "gpt-test", Some(&options))
            .await
            .expect("delegation should succeed");
        assert!(stream.next().await.is_none());

        assert_eq!(
            provider
                .requested_models
                .lock()
                .expect("lock poisoned")
                .as_slice(),
            ["gpt-test"]
        );
        assert_eq!(
            provider
                .requested_max_tokens
                .lock()
                .expect("lock poisoned")
                .as_slice(),
            [Some(512)]
        );
    }

    #[tokio::test]
    async fn list_models_returns_empty_by_default() {
        let provider = RecordingProvider::default();
        let models = provider
            .list_models()
            .await
            .expect("default list_models should succeed");
        assert!(models.is_empty());
    }

    #[tokio::test]
    async fn list_model_info_adapts_list_models_without_limits() {
        // A provider that only overrides `list_models`, so the default
        // `list_model_info` adapter is what gets exercised.
        struct TwoModels;
        #[async_trait]
        impl LLMProvider for TwoModels {
            async fn chat_stream(
                &self,
                _m: &[Message],
                _t: &[ToolSchema],
                _mt: Option<u32>,
                _model: &str,
            ) -> Result<LLMStream> {
                unreachable!("list_model_info must not start a chat stream")
            }
            async fn list_models(&self) -> Result<Vec<String>> {
                Ok(vec!["model-a".to_string(), "model-b".to_string()])
            }
        }

        let info = TwoModels
            .list_model_info()
            .await
            .expect("default list_model_info should succeed");
        assert_eq!(
            info,
            vec![
                ProviderModelInfo::from_id("model-a"),
                ProviderModelInfo::from_id("model-b"),
            ]
        );
    }

    #[test]
    fn request_options_default_has_no_purpose() {
        let opts = LLMRequestOptions::default();
        assert!(opts.request_purpose.is_none());
    }

    #[test]
    fn request_options_purpose_is_set_and_readable() {
        let opts = LLMRequestOptions {
            request_purpose: Some("title_generation".to_string()),
            ..Default::default()
        };
        assert_eq!(opts.request_purpose.as_deref(), Some("title_generation"));
    }
}