use crate::llm::types::LLMChunk;
use async_trait::async_trait;
use bamboo_domain::{Message, ReasoningEffort, ToolSchema};
use futures::Stream;
use std::pin::Pin;
use thiserror::Error;

/// Errors that can occur while talking to an LLM provider.
#[derive(Error, Debug)]
pub enum LLMError {
    #[error("HTTP error: {0}")]
    Http(#[from] reqwest::Error),
    #[error("JSON error: {0}")]
    Json(#[from] serde_json::Error),
    #[error("Stream error: {0}")]
    Stream(String),
    #[error("API error: {0}")]
    Api(String),
    #[error("Authentication error: {0}")]
    Auth(String),
    #[error("Protocol conversion error: {0}")]
    Protocol(#[from] crate::llm::protocol::ProtocolError),
}

/// Convenience alias for `std::result::Result` specialized to [`LLMError`].
pub type Result<T> = std::result::Result<T, LLMError>;

/// A pinned, boxed stream of LLM chunks, as produced by `chat_stream`.
pub type LLMStream = Pin<Box<dyn Stream<Item = Result<LLMChunk>> + Send>>;

/// Model metadata reported by a provider, with optional token limits.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ProviderModelInfo {
    pub id: String,
    pub max_context_tokens: Option<u32>,
    pub max_output_tokens: Option<u32>,
}

impl ProviderModelInfo {
    /// Builds an entry from a bare model id, leaving both token limits unset.
    pub fn from_id(id: impl Into<String>) -> Self {
        Self {
            id: id.into(),
            max_context_tokens: None,
            max_output_tokens: None,
        }
    }
}

/// Request options specific to Responses-style APIs.
#[derive(Debug, Clone, Default)]
pub struct ResponsesRequestOptions {
    pub instructions: Option<String>,
    pub reasoning_summary: Option<String>,
    pub include: Option<Vec<String>>,
    pub store: Option<bool>,
    pub previous_response_id: Option<String>,
    pub truncation: Option<String>,
    pub text_verbosity: Option<String>,
}

/// Per-request options passed through to a provider.
#[derive(Debug, Clone, Default)]
pub struct LLMRequestOptions {
    pub session_id: Option<String>,
    pub reasoning_effort: Option<ReasoningEffort>,
    pub parallel_tool_calls: Option<bool>,
    pub responses: Option<ResponsesRequestOptions>,
}

/// A streaming chat LLM backend.
///
/// Implementors only need to provide `chat_stream`; the remaining methods
/// have defaults that delegate to it or return empty results.
#[async_trait]
pub trait LLMProvider: Send + Sync {
    /// Streams a chat completion for `messages`, exposing `tools` to the model.
    async fn chat_stream(
        &self,
        messages: &[Message],
        tools: &[ToolSchema],
        max_output_tokens: Option<u32>,
        model: &str,
    ) -> Result<LLMStream>;

    /// Like `chat_stream`, but with extra per-request options. The default
    /// implementation ignores the options and delegates to `chat_stream`.
    async fn chat_stream_with_options(
        &self,
        messages: &[Message],
        tools: &[ToolSchema],
        max_output_tokens: Option<u32>,
        model: &str,
        _options: Option<&LLMRequestOptions>,
    ) -> Result<LLMStream> {
        self.chat_stream(messages, tools, max_output_tokens, model)
            .await
    }

    /// Lists the model ids this provider exposes. Defaults to an empty list.
    async fn list_models(&self) -> Result<Vec<String>> {
        Ok(vec![])
    }

    /// Lists model metadata. The default wraps each id from `list_models`
    /// with [`ProviderModelInfo::from_id`], leaving the token limits unset.
    async fn list_model_info(&self) -> Result<Vec<ProviderModelInfo>> {
        Ok(self
            .list_models()
            .await?
            .into_iter()
            .map(ProviderModelInfo::from_id)
            .collect())
    }
}

#[cfg(test)]
mod tests {
    use std::sync::{Arc, Mutex};

    use async_trait::async_trait;
    use futures::{stream, StreamExt};

    use super::*;

    /// Test double that records the model and token-limit arguments passed
    /// to each `chat_stream` call.
    #[derive(Clone, Default)]
    struct RecordingProvider {
        requested_models: Arc<Mutex<Vec<String>>>,
        requested_max_tokens: Arc<Mutex<Vec<Option<u32>>>>,
    }

    #[async_trait]
    impl LLMProvider for RecordingProvider {
        async fn chat_stream(
            &self,
            _messages: &[Message],
            _tools: &[ToolSchema],
            max_output_tokens: Option<u32>,
            model: &str,
        ) -> Result<LLMStream> {
            if let Ok(mut models) = self.requested_models.lock() {
                models.push(model.to_string());
            }
            if let Ok(mut max_tokens) = self.requested_max_tokens.lock() {
                max_tokens.push(max_output_tokens);
            }
            Ok(Box::pin(stream::empty()))
        }
    }

    #[tokio::test]
    async fn chat_stream_with_options_delegates_to_chat_stream_with_same_model_and_tokens() {
        let provider = RecordingProvider::default();
        let options = LLMRequestOptions::default();
        let mut stream = provider
            .chat_stream_with_options(&[], &[], Some(512), "gpt-test", Some(&options))
            .await
            .expect("delegation should succeed");
        assert!(stream.next().await.is_none());
        assert_eq!(
            provider
                .requested_models
                .lock()
                .expect("lock poisoned")
                .as_slice(),
            ["gpt-test"]
        );
        assert_eq!(
            provider
                .requested_max_tokens
                .lock()
                .expect("lock poisoned")
                .as_slice(),
            [Some(512)]
        );
    }
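
    /// The default `chat_stream_with_options` discards the options entirely;
    /// this sketch passes fully populated options (all field values below are
    /// illustrative test data) and checks that delegation is unchanged.
    #[tokio::test]
    async fn chat_stream_with_options_ignores_populated_options_by_default() {
        let provider = RecordingProvider::default();
        let options = LLMRequestOptions {
            session_id: Some("session-1".to_string()),
            reasoning_effort: None,
            parallel_tool_calls: Some(true),
            responses: Some(ResponsesRequestOptions {
                instructions: Some("be brief".to_string()),
                ..Default::default()
            }),
        };
        let mut stream = provider
            .chat_stream_with_options(&[], &[], None, "gpt-test", Some(&options))
            .await
            .expect("delegation should succeed");
        assert!(stream.next().await.is_none());
        // Delegation forwards the arguments untouched, options included or not.
        assert_eq!(
            provider
                .requested_max_tokens
                .lock()
                .expect("lock poisoned")
                .as_slice(),
            [None]
        );
    }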

    #[tokio::test]
    async fn list_models_returns_empty_by_default() {
        let provider = RecordingProvider::default();
        let models = provider
            .list_models()
            .await
            .expect("default list_models should succeed");
        assert!(models.is_empty());
    }
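
    /// `from_id` should populate only the id, leaving both token limits unset.
    #[test]
    fn from_id_leaves_token_limits_unset() {
        let info = ProviderModelInfo::from_id("gpt-test");
        assert_eq!(info.id, "gpt-test");
        assert_eq!(info.max_context_tokens, None);
        assert_eq!(info.max_output_tokens, None);
    }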
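
    /// A test-only provider (hypothetical name and model ids) that overrides
    /// `list_models` with fixed ids, used to exercise the default
    /// `list_model_info`.
    struct StaticModelsProvider;

    #[async_trait]
    impl LLMProvider for StaticModelsProvider {
        async fn chat_stream(
            &self,
            _messages: &[Message],
            _tools: &[ToolSchema],
            _max_output_tokens: Option<u32>,
            _model: &str,
        ) -> Result<LLMStream> {
            Ok(Box::pin(stream::empty()))
        }

        async fn list_models(&self) -> Result<Vec<String>> {
            Ok(vec!["model-a".to_string(), "model-b".to_string()])
        }
    }

    /// The default `list_model_info` should wrap each id from `list_models`
    /// via `ProviderModelInfo::from_id`, with both token limits left unset.
    #[tokio::test]
    async fn list_model_info_wraps_ids_from_list_models() {
        let provider = StaticModelsProvider;
        let info = provider
            .list_model_info()
            .await
            .expect("default list_model_info should succeed");
        assert_eq!(
            info,
            vec![
                ProviderModelInfo::from_id("model-a"),
                ProviderModelInfo::from_id("model-b"),
            ]
        );
    }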
}