use crate::prompt_ir::PromptIR;
use crate::types::LLMChunk;
use async_trait::async_trait;
use bamboo_domain::Message;
use bamboo_domain::ReasoningEffort;
use bamboo_domain::ToolSchema;
use futures::Stream;
use std::pin::Pin;
use thiserror::Error;
#[derive(Error, Debug)]
pub enum LLMError {
#[error("HTTP error: {0}")]
Http(#[from] reqwest::Error),
#[error("JSON error: {0}")]
Json(#[from] serde_json::Error),
#[error("Stream error: {0}")]
Stream(String),
#[error("API error: {0}")]
Api(String),
#[error("Authentication error: {0}")]
Auth(String),
#[error("Protocol conversion error: {0}")]
Protocol(#[from] crate::protocol::ProtocolError),
}
/// Module-wide result alias: `std::result::Result` with [`LLMError`] as the error type.
pub type Result<T> = std::result::Result<T, LLMError>;
/// A pinned, boxed, `Send` stream whose items are `Result<LLMChunk>`.
pub type LLMStream = Pin<Box<dyn Stream<Item = Result<LLMChunk>> + Send>>;
/// A model entry as listed by a provider: an id plus two optional token limits.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct ProviderModelInfo {
    /// Identifier of the model.
    pub id: String,
    /// Optional context-token limit; `None` when not provided.
    pub max_context_tokens: Option<u32>,
    /// Optional output-token limit; `None` when not provided.
    pub max_output_tokens: Option<u32>,
}

impl ProviderModelInfo {
    /// Creates an entry that carries only a model id.
    ///
    /// Accepts anything convertible into a `String`; both token limits are
    /// left as `None`.
    pub fn from_id(id: impl Into<String>) -> Self {
        let id = id.into();
        ProviderModelInfo {
            id,
            max_context_tokens: None,
            max_output_tokens: None,
        }
    }
}
/// Optional settings carried in [`LLMRequestOptions::responses`].
///
/// Every field is an `Option`, and `Default` yields a value with all of them
/// `None`.
///
/// NOTE(review): the field names look like request parameters of a
/// "Responses"-style API; how each one is sent on the wire is decided by the
/// provider that consumes this struct — confirm there.
#[derive(Clone, Debug, Default)]
pub struct ResponsesRequestOptions {
    /// Instruction text for the request, if any.
    pub instructions: Option<String>,
    /// Messages to use as the request input, if any.
    pub input_messages: Option<Vec<Message>>,
    /// Reasoning-summary setting, passed as a free-form string.
    pub reasoning_summary: Option<String>,
    /// List of `include` entries, passed as free-form strings.
    pub include: Option<Vec<String>>,
    /// Optional `store` flag.
    pub store: Option<bool>,
    /// Identifier of a previous response, if any.
    pub previous_response_id: Option<String>,
    /// Truncation setting, passed as a free-form string.
    pub truncation: Option<String>,
    /// Text-verbosity setting, passed as a free-form string.
    pub text_verbosity: Option<String>,
}
/// Optional per-request settings accepted by
/// [`LLMProvider::chat_stream_with_options`] and [`LLMProvider::chat_stream_ir`].
///
/// Every field is an `Option`, and `Default` yields a value with all of them
/// `None`.
#[derive(Clone, Debug, Default)]
pub struct LLMRequestOptions {
    /// Session identifier, if any.
    pub session_id: Option<String>,
    /// Requested reasoning effort, if any.
    pub reasoning_effort: Option<ReasoningEffort>,
    /// Optional flag for parallel tool calls.
    pub parallel_tool_calls: Option<bool>,
    /// Responses-specific settings. The default `chat_stream_ir` overwrites
    /// this with a value derived from the prompt IR.
    pub responses: Option<ResponsesRequestOptions>,
    /// Free-form label describing why the request is made
    /// (the unit tests use `"title_generation"` as an example value).
    pub request_purpose: Option<String>,
    /// Prompt cache plan, if any.
    pub cache: Option<crate::cache::PromptCachePlan>,
}
#[async_trait]
pub trait LLMProvider: Send + Sync {
async fn chat_stream(
&self,
messages: &[Message],
tools: &[ToolSchema],
max_output_tokens: Option<u32>,
model: &str,
) -> Result<LLMStream>;
async fn chat_stream_with_options(
&self,
messages: &[Message],
tools: &[ToolSchema],
max_output_tokens: Option<u32>,
model: &str,
_options: Option<&LLMRequestOptions>,
) -> Result<LLMStream> {
self.chat_stream(messages, tools, max_output_tokens, model)
.await
}
async fn chat_stream_ir(
&self,
ir: &PromptIR,
tools: &[ToolSchema],
max_output_tokens: Option<u32>,
model: &str,
options: Option<&LLMRequestOptions>,
) -> Result<LLMStream> {
let messages = if ir.continuation.is_some() {
ir.continuation_delta()
} else {
ir.flatten()
};
let mut effective_options = options.cloned().unwrap_or_default();
effective_options.responses =
Some(ir.responses_request_options(effective_options.responses.as_ref()));
self.chat_stream_with_options(
&messages,
tools,
max_output_tokens,
model,
Some(&effective_options),
)
.await
}
async fn list_models(&self) -> Result<Vec<String>> {
Ok(vec![])
}
async fn list_model_info(&self) -> Result<Vec<ProviderModelInfo>> {
Ok(self
.list_models()
.await?
.into_iter()
.map(ProviderModelInfo::from_id)
.collect())
}
}
#[cfg(test)]
mod tests {
    use std::sync::{Arc, Mutex};

    use async_trait::async_trait;
    use futures::{stream, StreamExt};

    use super::*;

    /// The default `chat_stream_ir` must pass the flattened IR to
    /// `chat_stream_with_options` and attach Responses options derived from
    /// the IR.
    #[tokio::test]
    async fn chat_stream_ir_default_flattens_and_delegates() {
        use crate::prompt_ir::{PromptIR, Segment, SegmentRole};

        /// Provider double that records what `chat_stream_with_options` receives.
        #[derive(Default)]
        struct Capture {
            seen: Arc<Mutex<Vec<Message>>>,
            seen_responses: Arc<Mutex<Option<crate::provider::ResponsesRequestOptions>>>,
        }

        #[async_trait]
        impl LLMProvider for Capture {
            async fn chat_stream(
                &self,
                _m: &[Message],
                _t: &[ToolSchema],
                _mt: Option<u32>,
                _model: &str,
            ) -> Result<LLMStream> {
                unreachable!("default chat_stream_ir must route via chat_stream_with_options")
            }

            async fn chat_stream_with_options(
                &self,
                messages: &[Message],
                _t: &[ToolSchema],
                _mt: Option<u32>,
                _model: &str,
                o: Option<&LLMRequestOptions>,
            ) -> Result<LLMStream> {
                let captured_messages = messages.to_vec();
                let captured_responses = match o {
                    Some(request_options) => request_options.responses.clone(),
                    None => None,
                };
                *self.seen.lock().expect("seen lock") = captured_messages;
                *self.seen_responses.lock().expect("resp lock") = captured_responses;

                let no_chunks: Vec<Result<LLMChunk>> = Vec::new();
                Ok(Box::pin(stream::iter(no_chunks)))
            }
        }

        let cap = Capture::default();
        let segments = vec![
            Segment::new(SegmentRole::StablePrefix, vec![Message::user("guide")]),
            Segment::new(SegmentRole::DynamicContext, vec![Message::user("dyn")]),
            Segment::new(SegmentRole::Conversation, vec![Message::user("ask")]),
        ];
        let ir = PromptIR {
            system_text: "sys".into(),
            segments,
            ..PromptIR::default()
        };

        // Only the recorded side effects matter; the returned stream is dropped.
        let _ = cap
            .chat_stream_ir(&ir, &[], None, "m", None)
            .await
            .expect("ir stream");

        // The delegated messages must equal the flattened IR, element by element.
        let seen = cap.seen.lock().expect("seen lock").clone();
        let expected = ir.flatten();
        assert_eq!(seen.len(), expected.len(), "delegates the flattened IR");
        for (actual, wanted) in seen.iter().zip(expected.iter()) {
            assert_eq!(actual.role, wanted.role);
            assert_eq!(actual.content, wanted.content);
        }
        assert_eq!(seen.len(), 4);
        assert!(matches!(seen[0].role, bamboo_domain::Role::System));

        // The Responses options must be derived from the IR as well.
        let recorded = cap.seen_responses.lock().expect("resp lock").clone();
        let responses = recorded.expect("default derives Responses options from the IR");
        assert_eq!(responses.instructions.as_deref(), Some("sys"));

        let input = responses.input_messages.expect("input_messages derived");
        let contents = input.iter().map(|m| m.content.clone()).collect::<Vec<_>>();
        assert_eq!(
            contents,
            vec![
                "guide".to_string(),
                "dyn".to_string(),
                "ask".to_string()
            ],
            "input_messages is the responses_input view: NO leading system message"
        );
    }

    /// Provider double that logs the model and token limit of each
    /// `chat_stream` call and answers with an empty stream.
    #[derive(Clone, Default)]
    struct RecordingProvider {
        requested_models: Arc<Mutex<Vec<String>>>,
        requested_max_tokens: Arc<Mutex<Vec<Option<u32>>>>,
    }

    #[async_trait]
    impl LLMProvider for RecordingProvider {
        async fn chat_stream(
            &self,
            _messages: &[Message],
            _tools: &[ToolSchema],
            max_output_tokens: Option<u32>,
            model: &str,
        ) -> Result<LLMStream> {
            // Recording is skipped when a lock cannot be acquired.
            if let Ok(mut model_log) = self.requested_models.lock() {
                model_log.push(model.to_owned());
            }
            if let Ok(mut token_log) = self.requested_max_tokens.lock() {
                token_log.push(max_output_tokens);
            }
            Ok(Box::pin(stream::empty()))
        }
    }

    /// The default `chat_stream_with_options` forwards model and token limit
    /// to `chat_stream` unchanged.
    #[tokio::test]
    async fn chat_stream_with_options_delegates_to_chat_stream_with_same_model_and_tokens() {
        let provider = RecordingProvider::default();
        let options = LLMRequestOptions::default();

        let mut stream = provider
            .chat_stream_with_options(&[], &[], Some(512), "gpt-test", Some(&options))
            .await
            .expect("delegation should succeed");
        let first_item = stream.next().await;
        assert!(first_item.is_none());

        let models = provider
            .requested_models
            .lock()
            .expect("lock poisoned")
            .clone();
        assert_eq!(models.as_slice(), ["gpt-test"]);

        let token_limits = provider
            .requested_max_tokens
            .lock()
            .expect("lock poisoned")
            .clone();
        assert_eq!(token_limits.as_slice(), [Some(512)]);
    }

    /// Providers that do not override `list_models` report no models.
    #[tokio::test]
    async fn list_models_returns_empty_by_default() {
        let provider = RecordingProvider::default();
        let listed = provider.list_models().await;
        let models = listed.expect("default list_models should succeed");
        assert!(models.is_empty());
    }

    /// `request_purpose` is unset in default options.
    #[test]
    fn request_options_default_has_no_purpose() {
        let defaults = LLMRequestOptions::default();
        assert!(defaults.request_purpose.is_none());
    }

    /// `request_purpose` can be set at construction and read back.
    #[test]
    fn request_options_purpose_is_set_and_readable() {
        let purpose = String::from("title_generation");
        let opts = LLMRequestOptions {
            request_purpose: Some(purpose),
            ..LLMRequestOptions::default()
        };
        assert_eq!(opts.request_purpose.as_deref(), Some("title_generation"));
    }
}