use crate::types::LLMChunk;
use async_trait::async_trait;
use bamboo_domain::Message;
use bamboo_domain::ReasoningEffort;
use bamboo_domain::ToolSchema;
use futures::Stream;
use std::pin::Pin;
use thiserror::Error;
/// Error type shared by everything in this module that talks to an LLM backend.
///
/// `Display` output comes from the `#[error(...)]` attribute on each variant;
/// variants marked `#[from]` additionally get a `From` conversion so the
/// wrapped error can be propagated with `?`.
#[derive(Error, Debug)]
pub enum LLMError {
    /// Wraps a `reqwest::Error`; created automatically through `From`.
    #[error("HTTP error: {0}")]
    Http(#[from] reqwest::Error),
    /// Wraps a `serde_json::Error`; created automatically through `From`.
    #[error("JSON error: {0}")]
    Json(#[from] serde_json::Error),
    /// Stream-related failure carrying a free-form message.
    #[error("Stream error: {0}")]
    Stream(String),
    /// API-level failure carrying a free-form message.
    // NOTE(review): presumably the error text returned by the remote API — confirm at construction sites.
    #[error("API error: {0}")]
    Api(String),
    /// Authentication failure carrying a free-form message.
    #[error("Authentication error: {0}")]
    Auth(String),
    /// Wraps a `crate::protocol::ProtocolError`; created automatically through `From`.
    #[error("Protocol conversion error: {0}")]
    Protocol(#[from] crate::protocol::ProtocolError),
}
/// Shorthand for `std::result::Result` with [`LLMError`] as the error type.
pub type Result<T> = std::result::Result<T, LLMError>;
/// Heap-allocated, pinned, `Send` stream whose items are `Result<LLMChunk>`.
///
/// This is the return type of the streaming methods on [`LLMProvider`]; each
/// item is either a chunk or an [`LLMError`].
pub type LLMStream = Pin<Box<dyn Stream<Item = Result<LLMChunk>> + Send>>;
/// A model identifier together with the token limits known for it, if any.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ProviderModelInfo {
    /// Model identifier.
    pub id: String,
    /// Context-window token limit, or `None` when not known.
    pub max_context_tokens: Option<u32>,
    /// Output token limit, or `None` when not known.
    pub max_output_tokens: Option<u32>,
}

impl ProviderModelInfo {
    /// Builds an entry that carries only the identifier.
    ///
    /// Both token limits are left as `None`.
    pub fn from_id(id: impl Into<String>) -> Self {
        let id: String = id.into();
        Self {
            id,
            max_context_tokens: None,
            max_output_tokens: None,
        }
    }
}
/// Optional, per-request settings grouped under [`LLMRequestOptions::responses`].
///
/// Every field is an `Option`, so the derived `Default` yields a value with all
/// fields set to `None`.
// NOTE(review): the field names look like they mirror the OpenAI "Responses" API
// request body — confirm against the provider code that consumes this struct.
#[derive(Debug, Clone, Default)]
pub struct ResponsesRequestOptions {
    /// Optional instructions text.
    pub instructions: Option<String>,
    /// Optional list of input messages.
    // NOTE(review): presumably overrides the messages passed to the chat call — confirm.
    pub input_messages: Option<Vec<Message>>,
    /// Optional reasoning-summary setting, carried as a string.
    pub reasoning_summary: Option<String>,
    /// Optional list of `include` entries, carried as strings.
    pub include: Option<Vec<String>>,
    /// Optional `store` flag.
    pub store: Option<bool>,
    /// Optional identifier of a previous response.
    pub previous_response_id: Option<String>,
    /// Optional truncation setting, carried as a string.
    pub truncation: Option<String>,
    /// Optional text-verbosity setting, carried as a string.
    pub text_verbosity: Option<String>,
}
/// Optional extras passed alongside a chat request to
/// `LLMProvider::chat_stream_with_options` and `LLMProvider::chat_stream_lanes`.
///
/// Every field is an `Option`, so the derived `Default` yields a value with all
/// fields set to `None`. The trait's default `chat_stream_with_options` ignores
/// these options entirely; providers that care must override it.
#[derive(Debug, Clone, Default)]
pub struct LLMRequestOptions {
    /// Optional session identifier.
    pub session_id: Option<String>,
    /// Optional reasoning-effort setting.
    pub reasoning_effort: Option<ReasoningEffort>,
    /// Optional flag for parallel tool calls.
    pub parallel_tool_calls: Option<bool>,
    /// Optional nested settings; see [`ResponsesRequestOptions`].
    pub responses: Option<ResponsesRequestOptions>,
    /// Optional free-form label describing why the request is made
    /// (the tests in this file use `"title_generation"` as an example value).
    pub request_purpose: Option<String>,
    /// Optional prompt-cache plan.
    pub cache: Option<crate::cache::PromptCachePlan>,
}
/// A prompt split into separate "lanes" that `flatten` concatenates in a fixed
/// order: system instructions, stable prefix, dynamic context, conversation.
///
/// The derived `Default` is an empty instruction string and three empty vectors.
#[derive(Debug, Clone, Default)]
pub struct PromptLanes {
    /// Text for the leading system message. `flatten` trims it and emits no
    /// system message at all when the trimmed text is empty.
    pub stable_instructions: String,
    /// Messages emitted first after the optional system message.
    pub stable_prefix_messages: Vec<Message>,
    /// Messages emitted after the stable prefix and before the conversation.
    pub dynamic_context_messages: Vec<Message>,
    /// Messages emitted last.
    pub conversation_messages: Vec<Message>,
}
impl PromptLanes {
    /// Concatenates all lanes into a single message list.
    ///
    /// The output order is fixed:
    /// 1. one system message built from the trimmed `stable_instructions`
    ///    (skipped when the trimmed text is empty),
    /// 2. `stable_prefix_messages`,
    /// 3. `dynamic_context_messages`,
    /// 4. `conversation_messages`.
    ///
    /// Messages are cloned; `self` is left untouched.
    pub fn flatten(&self) -> Vec<Message> {
        let system_text = self.stable_instructions.trim();
        let lanes = [
            &self.stable_prefix_messages,
            &self.dynamic_context_messages,
            &self.conversation_messages,
        ];

        // Reserve one extra slot for the optional system message.
        let capacity = 1 + lanes.iter().map(|lane| lane.len()).sum::<usize>();
        let mut flat = Vec::with_capacity(capacity);

        if !system_text.is_empty() {
            flat.push(Message::system(system_text.to_string()));
        }
        for lane in lanes {
            flat.extend_from_slice(lane);
        }
        flat
    }
}
/// Interface implemented by each LLM backend.
///
/// Only [`LLMProvider::chat_stream`] is required. Every other method has a
/// default body that either delegates down to it or returns an empty result, so
/// implementors override only what they actually support.
#[async_trait]
pub trait LLMProvider: Send + Sync {
    /// Required entry point: starts a chat request and returns its chunk stream.
    ///
    /// Receives the messages, the tool schemas, an optional output-token limit
    /// and the model name.
    async fn chat_stream(
        &self,
        messages: &[Message],
        tools: &[ToolSchema],
        max_output_tokens: Option<u32>,
        model: &str,
    ) -> Result<LLMStream>;

    /// Same as [`LLMProvider::chat_stream`], with optional request options.
    ///
    /// The default body discards `_options` and forwards the remaining
    /// arguments unchanged to `chat_stream`.
    async fn chat_stream_with_options(
        &self,
        messages: &[Message],
        tools: &[ToolSchema],
        max_output_tokens: Option<u32>,
        model: &str,
        _options: Option<&LLMRequestOptions>,
    ) -> Result<LLMStream> {
        let pending = self.chat_stream(messages, tools, max_output_tokens, model);
        pending.await
    }

    /// Lane-based variant of the chat call.
    ///
    /// The default body flattens `lanes` with [`PromptLanes::flatten`] and
    /// forwards the result, together with the other arguments, to
    /// `chat_stream_with_options`.
    async fn chat_stream_lanes(
        &self,
        lanes: &PromptLanes,
        tools: &[ToolSchema],
        max_output_tokens: Option<u32>,
        model: &str,
        options: Option<&LLMRequestOptions>,
    ) -> Result<LLMStream> {
        let flattened = lanes.flatten();
        let pending =
            self.chat_stream_with_options(&flattened, tools, max_output_tokens, model, options);
        pending.await
    }

    /// Lists model identifiers. The default body returns an empty list.
    async fn list_models(&self) -> Result<Vec<String>> {
        Ok(Vec::new())
    }

    /// Lists models with their token limits.
    ///
    /// The default body calls `list_models` and wraps each identifier with
    /// [`ProviderModelInfo::from_id`], so both limits are `None`; an error from
    /// `list_models` is returned as-is.
    async fn list_model_info(&self) -> Result<Vec<ProviderModelInfo>> {
        let ids = self.list_models().await?;
        let infos: Vec<ProviderModelInfo> =
            ids.into_iter().map(ProviderModelInfo::from_id).collect();
        Ok(infos)
    }
}
#[cfg(test)]
mod tests {
    use std::sync::{Arc, Mutex};

    use async_trait::async_trait;
    use futures::{stream, StreamExt};

    use super::*;

    /// `flatten` emits system, stable prefix, dynamic context, conversation —
    /// in that order — and trims the system text.
    #[test]
    fn prompt_lanes_flatten_preserves_canonical_order() {
        let lanes = PromptLanes {
            stable_instructions: " base system ".to_string(),
            stable_prefix_messages: vec![Message::user("tool-guide")],
            dynamic_context_messages: vec![Message::user("task-snapshot")],
            conversation_messages: vec![Message::user("real ask")],
        };
        let flat = lanes.flatten();
        assert_eq!(flat.len(), 4);
        assert!(matches!(flat[0].role, bamboo_domain::Role::System));
        assert_eq!(flat[0].content, "base system");
        assert_eq!(flat[1].content, "tool-guide");
        assert_eq!(flat[2].content, "task-snapshot");
        assert_eq!(flat[3].content, "real ask");
    }

    /// The default `chat_stream_lanes` must flatten the lanes and go through
    /// `chat_stream_with_options`, never straight to `chat_stream`.
    #[tokio::test]
    async fn chat_stream_lanes_default_flattens_and_delegates() {
        #[derive(Default)]
        struct Capture {
            seen: Arc<Mutex<Vec<Message>>>,
        }

        #[async_trait]
        impl LLMProvider for Capture {
            async fn chat_stream(
                &self,
                _m: &[Message],
                _t: &[ToolSchema],
                _mt: Option<u32>,
                _model: &str,
            ) -> Result<LLMStream> {
                unreachable!("default chat_stream_lanes must route via chat_stream_with_options")
            }

            async fn chat_stream_with_options(
                &self,
                messages: &[Message],
                _t: &[ToolSchema],
                _mt: Option<u32>,
                _model: &str,
                _o: Option<&LLMRequestOptions>,
            ) -> Result<LLMStream> {
                *self.seen.lock().expect("seen lock") = messages.to_vec();
                Ok(Box::pin(stream::iter(Vec::<Result<LLMChunk>>::new())))
            }
        }

        let cap = Capture::default();
        let lanes = PromptLanes {
            stable_instructions: "sys".into(),
            stable_prefix_messages: vec![Message::user("guide")],
            dynamic_context_messages: vec![Message::user("dyn")],
            conversation_messages: vec![Message::user("ask")],
        };
        let _ = cap
            .chat_stream_lanes(&lanes, &[], None, "m", None)
            .await
            .expect("lanes stream");

        let seen = cap.seen.lock().expect("seen lock").clone();
        let expected = lanes.flatten();
        assert_eq!(seen.len(), expected.len(), "delegates the flattened lanes");
        for (got, want) in seen.iter().zip(expected.iter()) {
            assert_eq!(got.role, want.role);
            assert_eq!(got.content, want.content);
        }
        assert_eq!(seen.len(), 4);
        assert!(matches!(seen[0].role, bamboo_domain::Role::System));
    }

    /// Whitespace-only instructions produce no system message.
    #[test]
    fn prompt_lanes_flatten_omits_empty_system() {
        let lanes = PromptLanes {
            stable_instructions: "   ".to_string(),
            conversation_messages: vec![Message::user("hi")],
            ..PromptLanes::default()
        };
        let flat = lanes.flatten();
        assert_eq!(flat.len(), 1);
        assert!(matches!(flat[0].role, bamboo_domain::Role::User));
    }

    /// Provider that implements only `chat_stream` and records the model name
    /// and token limit of every call.
    #[derive(Clone, Default)]
    struct RecordingProvider {
        requested_models: Arc<Mutex<Vec<String>>>,
        requested_max_tokens: Arc<Mutex<Vec<Option<u32>>>>,
    }

    #[async_trait]
    impl LLMProvider for RecordingProvider {
        async fn chat_stream(
            &self,
            _messages: &[Message],
            _tools: &[ToolSchema],
            max_output_tokens: Option<u32>,
            model: &str,
        ) -> Result<LLMStream> {
            // Fail loudly on a poisoned lock instead of silently dropping the
            // recording, which would surface later as a confusing assertion failure.
            self.requested_models
                .lock()
                .expect("requested_models lock poisoned")
                .push(model.to_string());
            self.requested_max_tokens
                .lock()
                .expect("requested_max_tokens lock poisoned")
                .push(max_output_tokens);
            Ok(Box::pin(stream::empty()))
        }
    }

    /// The default `chat_stream_with_options` forwards model and token limit
    /// unchanged to `chat_stream`.
    #[tokio::test]
    async fn chat_stream_with_options_delegates_to_chat_stream_with_same_model_and_tokens() {
        let provider = RecordingProvider::default();
        let options = LLMRequestOptions::default();
        let mut stream = provider
            .chat_stream_with_options(&[], &[], Some(512), "gpt-test", Some(&options))
            .await
            .expect("delegation should succeed");
        assert!(stream.next().await.is_none());
        assert_eq!(
            provider
                .requested_models
                .lock()
                .expect("lock poisoned")
                .as_slice(),
            ["gpt-test"]
        );
        assert_eq!(
            provider
                .requested_max_tokens
                .lock()
                .expect("lock poisoned")
                .as_slice(),
            [Some(512)]
        );
    }

    /// A provider that does not override `list_models` reports no models.
    #[tokio::test]
    async fn list_models_returns_empty_by_default() {
        let provider = RecordingProvider::default();
        let models = provider
            .list_models()
            .await
            .expect("default list_models should succeed");
        assert!(models.is_empty());
    }

    /// The default `list_model_info` wraps every id from `list_models` with
    /// `ProviderModelInfo::from_id`, preserving order and leaving limits unset.
    #[tokio::test]
    async fn list_model_info_default_wraps_listed_ids_without_limits() {
        struct FixedModels;

        #[async_trait]
        impl LLMProvider for FixedModels {
            async fn chat_stream(
                &self,
                _messages: &[Message],
                _tools: &[ToolSchema],
                _max_output_tokens: Option<u32>,
                _model: &str,
            ) -> Result<LLMStream> {
                unreachable!("list_model_info must not open a chat stream")
            }

            async fn list_models(&self) -> Result<Vec<String>> {
                Ok(vec!["model-a".to_string(), "model-b".to_string()])
            }
        }

        let info = FixedModels
            .list_model_info()
            .await
            .expect("default list_model_info should succeed");
        assert_eq!(
            info,
            vec![
                ProviderModelInfo::from_id("model-a"),
                ProviderModelInfo::from_id("model-b"),
            ]
        );
        assert!(info.iter().all(|m| m.max_context_tokens.is_none()));
        assert!(info.iter().all(|m| m.max_output_tokens.is_none()));
    }

    #[test]
    fn request_options_default_has_no_purpose() {
        let opts = LLMRequestOptions::default();
        assert!(opts.request_purpose.is_none());
    }

    #[test]
    fn request_options_purpose_is_set_and_readable() {
        let opts = LLMRequestOptions {
            request_purpose: Some("title_generation".to_string()),
            ..Default::default()
        };
        assert_eq!(opts.request_purpose.as_deref(), Some("title_generation"));
    }
}