Skip to main content

bamboo_llm/
provider.rs

//! LLM provider trait and types.
//!
//! This module defines the interface for LLM (Large Language Model) providers,
//! enabling support for multiple LLM backends through a common trait.
5
6use crate::prompt_ir::PromptIR;
7use crate::types::LLMChunk;
8use async_trait::async_trait;
9use bamboo_domain::Message;
10use bamboo_domain::ReasoningEffort;
11use bamboo_domain::ToolSchema;
12use futures::Stream;
13use std::pin::Pin;
14use thiserror::Error;
15
16/// Errors that can occur when working with LLM providers
17#[derive(Error, Debug)]
18pub enum LLMError {
19    /// HTTP request/response errors
20    #[error("HTTP error: {0}")]
21    Http(#[from] reqwest::Error),
22
23    /// JSON serialization/deserialization errors
24    #[error("JSON error: {0}")]
25    Json(#[from] serde_json::Error),
26
27    /// Streaming response errors
28    #[error("Stream error: {0}")]
29    Stream(String),
30
31    /// LLM API errors (rate limits, invalid requests, etc.)
32    #[error("API error: {0}")]
33    Api(String),
34
35    /// Authentication/authorization errors
36    #[error("Authentication error: {0}")]
37    Auth(String),
38
39    /// Protocol conversion errors
40    #[error("Protocol conversion error: {0}")]
41    Protocol(#[from] crate::protocol::ProtocolError),
42}
43
44/// Convenient result type for LLM operations
45pub type Result<T> = std::result::Result<T, LLMError>;
46
47/// Type alias for boxed streaming LLM responses
48pub type LLMStream = Pin<Box<dyn Stream<Item = Result<LLMChunk>> + Send>>;
49
/// One model entry as reported by [`LLMProvider::list_model_info`].
///
/// Token limits are optional because not every provider exposes them.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ProviderModelInfo {
    /// Identifier of the model as accepted by the provider.
    pub id: String,
    /// Total context window (input + output) in tokens, if known.
    pub max_context_tokens: Option<u32>,
    /// Upper bound on generated (output/completion) tokens, if known.
    pub max_output_tokens: Option<u32>,
}

impl ProviderModelInfo {
    /// Builds an entry that carries only the model id; both token limits are `None`.
    pub fn from_id(id: impl Into<String>) -> Self {
        let id: String = id.into();
        Self {
            id,
            max_context_tokens: None,
            max_output_tokens: None,
        }
    }
}
71
72/// Optional request-time controls for provider calls.
73#[derive(Debug, Clone, Default)]
74pub struct ResponsesRequestOptions {
75    /// Optional top-level instructions for Responses API requests.
76    pub instructions: Option<String>,
77    /// Optional message list to serialize into the Responses API `input` array.
78    ///
79    /// When omitted, providers fall back to the generic `messages` slice passed
80    /// to `chat_stream_with_options`. This lets the engine provide a
81    /// Responses-specific input view (for example, without a duplicated stable
82    /// system message) while preserving backward compatibility for non-Responses
83    /// callers and providers.
84    pub input_messages: Option<Vec<Message>>,
85    /// Optional reasoning summary control for Responses API requests
86    /// (e.g. "auto", "concise", "detailed").
87    pub reasoning_summary: Option<String>,
88    /// Optional include list for Responses API requests.
89    pub include: Option<Vec<String>>,
90    /// Whether Responses API should store the response server-side.
91    pub store: Option<bool>,
92    /// Optional continuation handle for stateful Responses API turns.
93    pub previous_response_id: Option<String>,
94    /// Optional truncation mode for Responses API requests
95    /// (e.g. "auto", "disabled").
96    pub truncation: Option<String>,
97    /// Optional text verbosity for Responses API requests
98    /// (e.g. "low", "medium", "high").
99    pub text_verbosity: Option<String>,
100}
101
102/// Optional request-time controls for provider calls.
103#[derive(Debug, Clone, Default)]
104pub struct LLMRequestOptions {
105    /// Session identifier used for request-scoped logging correlation.
106    pub session_id: Option<String>,
107    /// Override reasoning effort for this request.
108    pub reasoning_effort: Option<ReasoningEffort>,
109    /// Request provider-side parallel tool call planning when supported.
110    ///
111    /// - OpenAI/Copilot: maps to `parallel_tool_calls`
112    /// - Anthropic: maps to `tool_choice.disable_parallel_tool_use` (inverse)
113    pub parallel_tool_calls: Option<bool>,
114    /// Responses API specific overrides.
115    pub responses: Option<ResponsesRequestOptions>,
116    /// Purpose of this request for observability (e.g., "agent_loop", "task_evaluation").
117    pub request_purpose: Option<String>,
118    /// Provider-agnostic prompt-cache plan describing the stable, cacheable
119    /// prefix of this request. Providers render it in their own dialect
120    /// (Anthropic `cache_control` breakpoints; OpenAI/Gemini rely on the stable
121    /// prefix automatically). `None` means "no explicit cache hints".
122    pub cache: Option<crate::cache::PromptCachePlan>,
123}
124
125/// Trait for LLM provider implementations
126///
127/// This trait defines the interface that all LLM providers must implement
128/// to work with Bamboo's agent system. Providers handle communication with
129/// specific LLM services (OpenAI, Anthropic, local models, etc.).
130///
131/// # Design Principle
132///
133/// The `model` parameter is **required** in `chat_stream`, not optional.
134/// This ensures that the calling code explicitly specifies which model to use,
135/// preventing accidental use of unintended models and making model selection
136/// explicit and auditable.
137///
138/// # Example
139///
140/// ```ignore
141/// use bamboo_agent::agent::llm::provider::LLMProvider;
142///
143/// async fn use_provider(provider: &dyn LLMProvider) {
144///     let stream = provider.chat_stream(
145///         &messages,
146///         &tools,
147///         Some(4096),
148///         "claude-sonnet-4-6", // Model is required
149///     ).await?;
150/// }
151/// ```
152#[async_trait]
153pub trait LLMProvider: Send + Sync {
154    /// Stream chat completion from the LLM
155    ///
156    /// This is the primary method for interacting with LLMs, returning
157    /// a stream of response chunks that can be processed incrementally.
158    ///
159    /// # Arguments
160    ///
161    /// * `messages` - Conversation history and current prompt
162    /// * `tools` - Available tools the LLM can call
163    /// * `max_output_tokens` - Optional limit on response length
164    /// * `model` - **Required** model identifier (e.g., "claude-sonnet-4-6")
165    ///
166    /// # Returns
167    ///
168    /// A stream of `LLMChunk` items containing partial responses
169    ///
170    /// # Errors
171    ///
172    /// Returns `LLMError` on network failures, API errors, or invalid requests
173    async fn chat_stream(
174        &self,
175        messages: &[Message],
176        tools: &[ToolSchema],
177        max_output_tokens: Option<u32>,
178        model: &str,
179    ) -> Result<LLMStream>;
180
181    /// Stream chat completion with optional request-level controls.
182    ///
183    /// Default implementation preserves backward compatibility by delegating to
184    /// [`LLMProvider::chat_stream`].
185    async fn chat_stream_with_options(
186        &self,
187        messages: &[Message],
188        tools: &[ToolSchema],
189        max_output_tokens: Option<u32>,
190        model: &str,
191        _options: Option<&LLMRequestOptions>,
192    ) -> Result<LLMStream> {
193        self.chat_stream(messages, tools, max_output_tokens, model)
194            .await
195    }
196
197    /// Stream from the canonical [`PromptIR`] — the single, rich, provider-agnostic
198    /// request the engine emits once per round.
199    ///
200    /// A provider renders the IR into its own wire format by calling the lowering
201    /// methods ([`PromptIR::system_field`], [`PromptIR::body_chat`],
202    /// [`PromptIR::responses_input`], [`PromptIR::continuation_delta`]). The IR
203    /// carries the stateful Responses continuation, so an adapter derives the
204    /// delta itself rather than the engine pre-baking it.
205    ///
206    /// The default implementation lowers the IR for BOTH wire families and
207    /// delegates to [`chat_stream_with_options`](Self::chat_stream_with_options):
208    /// - the flat message list (`continuation_delta` mid-tool-loop, else `flatten`)
209    ///   for the Chat-Completions path;
210    /// - the Responses-API view (`instructions` / `input_messages` /
211    ///   `previous_response_id`) derived via [`PromptIR::responses_request_options`]
212    ///   and merged onto the request POLICY, so a Responses provider works WITHOUT
213    ///   overriding this method (Chat-Completions providers ignore those options).
214    ///
215    /// This is byte-identical to the pre-IR request. Block-native providers (e.g.
216    /// Anthropic) still override this to consume `system_blocks` structurally.
217    async fn chat_stream_ir(
218        &self,
219        ir: &PromptIR,
220        tools: &[ToolSchema],
221        max_output_tokens: Option<u32>,
222        model: &str,
223        options: Option<&LLMRequestOptions>,
224    ) -> Result<LLMStream> {
225        let messages = if ir.continuation.is_some() {
226            ir.continuation_delta()
227        } else {
228            ir.flatten()
229        };
230        let mut effective_options = options.cloned().unwrap_or_default();
231        effective_options.responses =
232            Some(ir.responses_request_options(effective_options.responses.as_ref()));
233        self.chat_stream_with_options(
234            &messages,
235            tools,
236            max_output_tokens,
237            model,
238            Some(&effective_options),
239        )
240        .await
241    }
242
243    /// Lists available models from this provider
244    ///
245    /// Returns a list of model identifiers that can be used with `chat_stream`.
246    /// Default implementation returns an empty list.
247    async fn list_models(&self) -> Result<Vec<String>> {
248        // Default implementation returns empty list
249        Ok(vec![])
250    }
251
252    /// Lists available models with optional token limit metadata.
253    ///
254    /// Default implementation preserves backward compatibility by adapting
255    /// `list_models()` output into metadata entries without limits.
256    async fn list_model_info(&self) -> Result<Vec<ProviderModelInfo>> {
257        Ok(self
258            .list_models()
259            .await?
260            .into_iter()
261            .map(ProviderModelInfo::from_id)
262            .collect())
263    }
264}
265
#[cfg(test)]
mod tests {
    use std::sync::{Arc, Mutex};

    use async_trait::async_trait;
    use futures::{stream, StreamExt};

    use super::*;

    #[tokio::test]
    async fn chat_stream_ir_default_flattens_and_delegates() {
        use crate::prompt_ir::{PromptIR, Segment, SegmentRole};

        // A provider that captures the message list AND the options it is handed.
        #[derive(Default)]
        struct Capture {
            seen: Arc<Mutex<Vec<Message>>>,
            seen_responses: Arc<Mutex<Option<crate::provider::ResponsesRequestOptions>>>,
        }
        #[async_trait]
        impl LLMProvider for Capture {
            async fn chat_stream(
                &self,
                _m: &[Message],
                _t: &[ToolSchema],
                _mt: Option<u32>,
                _model: &str,
            ) -> Result<LLMStream> {
                unreachable!("default chat_stream_ir must route via chat_stream_with_options")
            }
            async fn chat_stream_with_options(
                &self,
                messages: &[Message],
                _t: &[ToolSchema],
                _mt: Option<u32>,
                _model: &str,
                o: Option<&LLMRequestOptions>,
            ) -> Result<LLMStream> {
                *self.seen.lock().expect("seen lock") = messages.to_vec();
                *self.seen_responses.lock().expect("resp lock") =
                    o.and_then(|value| value.responses.clone());
                Ok(Box::pin(stream::iter(Vec::<Result<LLMChunk>>::new())))
            }
        }

        let cap = Capture::default();
        let ir = PromptIR {
            system_text: "sys".into(),
            segments: vec![
                Segment::new(SegmentRole::StablePrefix, vec![Message::user("guide")]),
                Segment::new(SegmentRole::DynamicContext, vec![Message::user("dyn")]),
                Segment::new(SegmentRole::Conversation, vec![Message::user("ask")]),
            ],
            ..PromptIR::default()
        };
        let _ = cap
            .chat_stream_ir(&ir, &[], None, "m", None)
            .await
            .expect("ir stream");

        let seen = cap.seen.lock().expect("seen lock").clone();
        let expected = ir.flatten();
        assert_eq!(seen.len(), expected.len(), "delegates the flattened IR");
        for (got, want) in seen.iter().zip(expected.iter()) {
            assert_eq!(got.role, want.role);
            assert_eq!(got.content, want.content);
        }
        // system + guide + dyn + ask
        assert_eq!(seen.len(), 4);
        assert!(matches!(seen[0].role, bamboo_domain::Role::System));

        // SAFETY NET: the default also derives the Responses-API view from the IR, so
        // a Responses provider works without overriding `chat_stream_ir`. instructions
        // = the (trimmed) system field; input_messages = the full responses_input view
        // (system lifted out, so it does not lead with a system message).
        let responses = cap
            .seen_responses
            .lock()
            .expect("resp lock")
            .clone()
            .expect("default derives Responses options from the IR");
        assert_eq!(responses.instructions.as_deref(), Some("sys"));
        let input = responses.input_messages.expect("input_messages derived");
        assert_eq!(
            input.iter().map(|m| m.content.clone()).collect::<Vec<_>>(),
            vec!["guide".to_string(), "dyn".to_string(), "ask".to_string()],
            "input_messages is the responses_input view: NO leading system message"
        );
    }

    /// Records the model and token limit of every `chat_stream` call so tests can
    /// check what the default trait methods forward.
    #[derive(Clone, Default)]
    struct RecordingProvider {
        requested_models: Arc<Mutex<Vec<String>>>,
        requested_max_tokens: Arc<Mutex<Vec<Option<u32>>>>,
    }

    #[async_trait]
    impl LLMProvider for RecordingProvider {
        async fn chat_stream(
            &self,
            _messages: &[Message],
            _tools: &[ToolSchema],
            max_output_tokens: Option<u32>,
            model: &str,
        ) -> Result<LLMStream> {
            // Fail loudly on a poisoned lock (like the assertions below) rather
            // than silently dropping the recording and producing a confusing diff.
            self.requested_models
                .lock()
                .expect("requested_models lock poisoned")
                .push(model.to_string());
            self.requested_max_tokens
                .lock()
                .expect("requested_max_tokens lock poisoned")
                .push(max_output_tokens);

            Ok(Box::pin(stream::empty()))
        }
    }

    #[tokio::test]
    async fn chat_stream_with_options_delegates_to_chat_stream_with_same_model_and_tokens() {
        let provider = RecordingProvider::default();
        let options = LLMRequestOptions::default();

        let mut stream = provider
            .chat_stream_with_options(&[], &[], Some(512), "gpt-test", Some(&options))
            .await
            .expect("delegation should succeed");
        assert!(stream.next().await.is_none());

        assert_eq!(
            provider
                .requested_models
                .lock()
                .expect("lock poisoned")
                .as_slice(),
            ["gpt-test"]
        );
        assert_eq!(
            provider
                .requested_max_tokens
                .lock()
                .expect("lock poisoned")
                .as_slice(),
            [Some(512)]
        );
    }

    #[tokio::test]
    async fn list_models_returns_empty_by_default() {
        let provider = RecordingProvider::default();
        let models = provider
            .list_models()
            .await
            .expect("default list_models should succeed");
        assert!(models.is_empty());
    }

    #[tokio::test]
    async fn list_model_info_default_adapts_list_models_without_limits() {
        // A provider that only overrides `list_models`; `list_model_info` must
        // come from the trait default.
        struct ListingProvider;

        #[async_trait]
        impl LLMProvider for ListingProvider {
            async fn chat_stream(
                &self,
                _messages: &[Message],
                _tools: &[ToolSchema],
                _max_output_tokens: Option<u32>,
                _model: &str,
            ) -> Result<LLMStream> {
                Ok(Box::pin(stream::empty()))
            }

            async fn list_models(&self) -> Result<Vec<String>> {
                Ok(vec!["model-a".to_string(), "model-b".to_string()])
            }
        }

        let info = ListingProvider
            .list_model_info()
            .await
            .expect("default list_model_info should succeed");
        assert_eq!(
            info,
            vec![
                ProviderModelInfo::from_id("model-a"),
                ProviderModelInfo::from_id("model-b"),
            ]
        );
        assert!(info
            .iter()
            .all(|m| m.max_context_tokens.is_none() && m.max_output_tokens.is_none()));
    }

    #[test]
    fn request_options_default_has_no_purpose() {
        let opts = LLMRequestOptions::default();
        assert!(opts.request_purpose.is_none());
    }

    #[test]
    fn request_options_purpose_is_set_and_readable() {
        let opts = LLMRequestOptions {
            request_purpose: Some("title_generation".to_string()),
            ..Default::default()
        };
        assert_eq!(opts.request_purpose.as_deref(), Some("title_generation"));
    }
}