use crate::config::Config;
use crate::session::Message;
use tokio::sync::watch;
pub use octolib::llm::{
AiProvider, AmazonBedrockProvider, AnthropicProvider, CloudflareWorkersAiProvider,
DeepSeekProvider, GenericToolCall, GoogleVertexProvider, OpenAiProvider, OpenRouterProvider,
ProviderFactory, StructuredOutputRequest,
};
pub use octolib::llm::{ModelPricing, ProviderExchange, ThinkingBlock, TokenUsage};
/// A chat-completion reply expressed in this crate's types.
///
/// Built by `convert_response_from_octolib` from an octolib `ProviderResponse`.
#[derive(Clone, Debug)]
pub struct ProviderResponse {
    /// Text content of the reply.
    pub content: String,
    /// Exchange record passed through unchanged from octolib.
    /// NOTE(review): presumably carries request/response/usage data — confirm in octolib.
    pub exchange: ProviderExchange,
    /// Tool calls requested by the model, converted to MCP tool calls;
    /// `None` when octolib's response carried `None`.
    pub tool_calls: Option<Vec<crate::mcp::McpToolCall>>,
    /// Thinking block returned by the provider, if any.
    pub thinking: Option<ThinkingBlock>,
    /// Finish reason as reported by the provider, if any.
    pub finish_reason: Option<String>,
    /// Response identifier (octolib's `id` field), if any.
    pub response_id: Option<String>,
    /// Structured output returned by octolib, if any
    /// (presumably only set when structured output was requested — confirm).
    pub structured_output: Option<serde_json::Value>,
}
/// Parameters for one chat-completion request, expressed in this crate's types.
///
/// Create with `ChatCompletionParams::new` and convert with `to_octolib_params`
/// before handing the request to octolib.
#[derive(Clone)]
pub struct ChatCompletionParams<'a> {
    /// Session messages to send; converted to octolib messages on conversion.
    pub messages: &'a [Message],
    /// Model identifier, passed through to octolib unchanged.
    pub model: &'a str,
    /// Sampling temperature, passed through unchanged.
    pub temperature: f32,
    /// Nucleus-sampling `top_p`, passed through unchanged.
    pub top_p: f32,
    /// `top_k` sampling limit, passed through unchanged.
    pub top_k: u32,
    /// Maximum number of tokens to generate, passed through unchanged.
    pub max_tokens: u32,
    /// Retry budget; `new` initializes it from `config.max_retries`.
    pub max_retries: u32,
    /// Retry timeout; `new` initializes it from `config.retry_timeout` (whole seconds).
    pub retry_timeout: std::time::Duration,
    /// Application configuration, consulted for MCP servers and cache TTL on conversion.
    pub config: &'a Config,
    /// Optional watch channel forwarded to octolib.
    /// NOTE(review): presumably `true` signals cancellation — confirm in octolib.
    pub cancellation_token: Option<watch::Receiver<bool>>,
    /// Optional JSON schema; when set, strict structured output is requested.
    pub schema: Option<serde_json::Value>,
}
impl<'a> ChatCompletionParams<'a> {
    /// Creates request parameters.
    ///
    /// Retry settings come from `config`: `max_retries` directly, and
    /// `retry_timeout` interpreted as whole seconds. No cancellation token and
    /// no output schema are set; use the `with_*` builders for those.
    pub fn new(
        messages: &'a [Message],
        model: &'a str,
        temperature: f32,
        top_p: f32,
        top_k: u32,
        max_tokens: u32,
        config: &'a Config,
    ) -> Self {
        Self {
            messages,
            model,
            temperature,
            top_p,
            top_k,
            max_tokens,
            max_retries: config.max_retries,
            // NOTE(review): `as u64` assumes `retry_timeout` is non-negative — confirm its type.
            retry_timeout: std::time::Duration::from_secs(config.retry_timeout as u64),
            config,
            cancellation_token: None,
            schema: None,
        }
    }

    /// Overrides the retry budget taken from the config.
    #[must_use]
    pub fn with_max_retries(mut self, max_retries: u32) -> Self {
        self.max_retries = max_retries;
        self
    }

    /// Attaches a watch channel that is forwarded to octolib as its cancellation token.
    #[must_use]
    pub fn with_cancellation_token(mut self, token: watch::Receiver<bool>) -> Self {
        self.cancellation_token = Some(token);
        self
    }

    /// Requests strict JSON-schema structured output using `schema`.
    #[must_use]
    pub fn with_schema(mut self, schema: serde_json::Value) -> Self {
        self.schema = Some(schema);
        self
    }

    /// Converts these parameters into octolib's `ChatCompletionParams`.
    ///
    /// Steps:
    /// 1. Every session message is converted with `convert_message_to_octolib`.
    /// 2. If the last non-system message is from the assistant, a synthetic user
    ///    message ("Please continue.") is appended so the conversation ends on a
    ///    user turn.
    /// 3. Retry settings and the optional cancellation token are forwarded.
    /// 4. When MCP servers are configured and expose functions, those functions
    ///    are attached as tools.
    /// 5. If a schema was set, strict JSON-schema structured output is requested.
    ///
    /// # Errors
    ///
    /// Returns the first `MessageError` produced by converting a message or by
    /// building the synthetic user message.
    pub async fn to_octolib_params(
        &self,
    ) -> Result<octolib::llm::ChatCompletionParams, octolib::MessageError> {
        let mut octolib_messages = self
            .messages
            .iter()
            .map(convert_message_to_octolib)
            .collect::<Result<Vec<octolib::llm::Message>, _>>()?;

        let last_non_system_is_assistant = octolib_messages
            .iter()
            .rev()
            .find(|m| m.role != "system")
            .map(|m| m.role == "assistant")
            .unwrap_or(false);
        if last_non_system_is_assistant {
            crate::log_debug!(
                "Last message is assistant after compression - appending synthetic user message to satisfy provider requirements"
            );
            // Propagate the builder's real error rather than masking it as an
            // unrelated `InvalidRole`, which would mislead whoever debugs it.
            let synthetic = octolib::llm::MessageBuilder::user("Please continue.").build()?;
            octolib_messages.push(synthetic);
        }

        let mut params = octolib::llm::ChatCompletionParams::new(
            &octolib_messages,
            self.model,
            self.temperature,
            self.top_p,
            self.top_k,
            self.max_tokens,
        )
        .with_max_retries(self.max_retries)
        .with_retry_timeout(self.retry_timeout);

        if let Some(token) = &self.cancellation_token {
            params = params.with_cancellation_token(token.clone());
        }

        if let Some(tools) = self.mcp_tool_definitions().await {
            params = params.with_tools(tools);
        }

        if let Some(ref schema) = self.schema {
            params = params.with_structured_output(
                StructuredOutputRequest::json_schema(schema.clone()).with_strict_mode(),
            );
        }
        Ok(params)
    }

    /// Builds octolib tool definitions from the MCP functions available under
    /// `self.config`.
    ///
    /// Returns `None` (so no tools are attached at all) when no MCP servers are
    /// configured or none of them expose functions. When any system message is
    /// marked cached, the last tool gets an ephemeral `cache_control` marker whose
    /// TTL is "1h" with `use_long_system_cache` and "5m" otherwise.
    async fn mcp_tool_definitions(&self) -> Option<Vec<octolib::llm::FunctionDefinition>> {
        if self.config.mcp.servers.is_empty() {
            return None;
        }
        let mcp_functions = crate::mcp::get_available_functions(self.config).await;
        if mcp_functions.is_empty() {
            return None;
        }

        let mut octolib_tools: Vec<octolib::llm::FunctionDefinition> = mcp_functions
            .into_iter()
            .map(|f| octolib::llm::FunctionDefinition {
                name: f.name,
                description: f.description,
                parameters: f.parameters,
                cache_control: None,
            })
            .collect();

        let system_cached = self.messages.iter().any(|m| m.role == "system" && m.cached);
        if system_cached {
            // Only the last tool carries the marker; `last_mut` is `None` on an
            // empty list, so no separate emptiness check is needed.
            if let Some(last_tool) = octolib_tools.last_mut() {
                let ttl = if self.config.use_long_system_cache {
                    "1h"
                } else {
                    "5m"
                };
                last_tool.cache_control = Some(serde_json::json!({
                    "type": "ephemeral",
                    "ttl": ttl
                }));
            }
        }
        Some(octolib_tools)
    }
}
fn convert_message_to_octolib(
msg: &Message,
) -> Result<octolib::llm::Message, octolib::MessageError> {
let mut builder = match msg.role.as_str() {
"user" => octolib::llm::MessageBuilder::user(&msg.content),
"assistant" => {
let mut builder = octolib::llm::MessageBuilder::assistant(&msg.content);
if let Some(ref tool_calls) = msg.tool_calls {
let generic_calls = convert_to_generic_tool_calls(tool_calls);
if !generic_calls.is_empty() {
builder = builder.with_tool_calls(generic_calls);
}
}
builder
}
"system" => octolib::llm::MessageBuilder::system(&msg.content),
"tool" => {
let tool_call_id = msg.tool_call_id.as_deref().ok_or_else(|| {
octolib::MessageError::MissingToolField {
field: "tool_call_id".to_string(),
}
})?;
let name =
msg.name
.as_deref()
.ok_or_else(|| octolib::MessageError::MissingToolField {
field: "name".to_string(),
})?;
octolib::llm::MessageBuilder::tool(
msg.content.clone(),
tool_call_id.to_string(),
name.to_string(),
)
}
_ => {
return Err(octolib::MessageError::InvalidRole {
role: msg.role.clone(),
})
}
};
builder = builder.timestamp(msg.timestamp);
if let Some(ref id) = msg.id {
builder = builder.id(id);
}
if msg.cached {
builder = builder.cached();
}
if let Some(images) = &msg.images {
let octolib_images: Vec<octolib::llm::ImageAttachment> =
images.iter().map(convert_image_to_octolib).collect();
builder = builder.with_images(octolib_images);
}
if let Some(videos) = &msg.videos {
let octolib_videos: Vec<octolib::llm::VideoAttachment> =
videos.iter().map(convert_video_to_octolib).collect();
builder = builder.with_videos(octolib_videos);
}
if let Some(ref thinking_value) = msg.thinking {
match serde_json::from_value::<octolib::ThinkingBlock>(thinking_value.clone()) {
Ok(thinking_block) => {
builder = builder.thinking(thinking_block);
}
Err(e) => {
crate::log_debug!(
"Failed to deserialize thinking field for {} message: {}. Value: {:?}",
msg.role,
e,
thinking_value
);
}
}
}
builder.build()
}
/// Maps a session image attachment onto octolib's `ImageAttachment`.
///
/// Payload (base64 or URL) and source kind (file, clipboard, URL) are
/// translated variant by variant. Media type, dimensions and size are copied.
fn convert_image_to_octolib(
    img: &crate::session::image::ImageAttachment,
) -> octolib::llm::ImageAttachment {
    use crate::session::image::{ImageData as SessionData, SourceType as SessionSource};
    use octolib::llm::{ImageData as LlmData, SourceType as LlmSource};

    octolib::llm::ImageAttachment {
        data: match &img.data {
            SessionData::Base64(encoded) => LlmData::Base64(encoded.clone()),
            SessionData::Url(link) => LlmData::Url(link.clone()),
        },
        media_type: img.media_type.clone(),
        source_type: match &img.source_type {
            SessionSource::File(file_path) => LlmSource::File(file_path.clone()),
            SessionSource::Clipboard => LlmSource::Clipboard,
            SessionSource::Url => LlmSource::Url,
        },
        dimensions: img.dimensions,
        size_bytes: img.size_bytes,
    }
}
/// Maps a session video attachment onto octolib's `VideoAttachment`.
///
/// Payload (base64 or URL) and source kind (file, clipboard, URL) are
/// translated variant by variant. Media type, dimensions, size and duration
/// are copied.
fn convert_video_to_octolib(
    video: &crate::session::video::VideoAttachment,
) -> octolib::llm::VideoAttachment {
    use crate::session::video::{SourceType as SessionSource, VideoData as SessionData};
    use octolib::llm::{SourceType as LlmSource, VideoData as LlmData};

    octolib::llm::VideoAttachment {
        data: match &video.data {
            SessionData::Base64(encoded) => LlmData::Base64(encoded.clone()),
            SessionData::Url(link) => LlmData::Url(link.clone()),
        },
        media_type: video.media_type.clone(),
        source_type: match &video.source_type {
            SessionSource::File(file_path) => LlmSource::File(file_path.clone()),
            SessionSource::Clipboard => LlmSource::Clipboard,
            SessionSource::Url => LlmSource::Url,
        },
        dimensions: video.dimensions,
        size_bytes: video.size_bytes,
        duration_secs: video.duration_secs,
    }
}
fn convert_to_generic_tool_calls(
tool_calls: &serde_json::Value,
) -> Vec<octolib::llm::GenericToolCall> {
if let Ok(calls) =
serde_json::from_value::<Vec<octolib::llm::GenericToolCall>>(tool_calls.clone())
{
return calls;
}
if let Some(calls_array) = tool_calls.as_array() {
let mut generic_calls = Vec::new();
for call in calls_array {
if let Some(function) = call.get("function") {
if let (Some(id), Some(name), Some(args_str)) = (
call.get("id").and_then(|v| v.as_str()),
function.get("name").and_then(|v| v.as_str()),
function.get("arguments").and_then(|v| v.as_str()),
) {
let arguments = if args_str.trim().is_empty() {
serde_json::json!({})
} else {
match serde_json::from_str::<serde_json::Value>(args_str) {
Ok(json_args) => json_args,
Err(e) => {
panic!("Failed to parse tool call arguments '{}': {}", args_str, e);
}
}
};
generic_calls.push(octolib::llm::GenericToolCall {
id: id.to_string(),
name: name.to_string(),
arguments,
meta: None, });
} else {
panic!("Invalid OpenAI tool call format - missing required fields");
}
} else {
panic!("Invalid tool call format - missing 'function' field");
}
}
return generic_calls;
}
panic!("Unsupported tool_calls format - must be Vec<GenericToolCall> or OpenAI format array");
}
/// Converts an octolib provider response into this crate's `ProviderResponse`.
///
/// Tool calls are mapped field by field onto MCP tool calls:
/// - `id` becomes `tool_id`;
/// - `name` becomes `tool_name`;
/// - `arguments` becomes `parameters`.
///
/// Octolib's `id` becomes `response_id`. All other fields move across unchanged.
pub fn convert_response_from_octolib(response: octolib::llm::ProviderResponse) -> ProviderResponse {
    let octolib::llm::ProviderResponse {
        content,
        exchange,
        tool_calls: raw_calls,
        thinking,
        finish_reason,
        id,
        structured_output,
        ..
    } = response;

    let tool_calls = raw_calls.map(|calls| {
        calls
            .into_iter()
            .map(|call| crate::mcp::McpToolCall {
                tool_id: call.id,
                tool_name: call.name,
                parameters: call.arguments,
            })
            .collect::<Vec<_>>()
    });

    ProviderResponse {
        response_id: id,
        content,
        exchange,
        tool_calls,
        thinking,
        finish_reason,
        structured_output,
    }
}
pub mod retry {
pub use octolib::llm::retry::*;
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A `ThinkingBlock` must survive a JSON round trip, mirroring how stored
    /// thinking values are deserialized when messages are converted.
    #[test]
    fn test_thinking_block_conversion() {
        let original = ThinkingBlock {
            content: "Test thinking content".to_string(),
            tokens: 42,
        };

        let json_value = serde_json::to_value(&original).expect("Failed to serialize");
        println!("Serialized: {}", json_value);

        let round_tripped =
            serde_json::from_value::<ThinkingBlock>(json_value).expect("Failed to deserialize");
        println!("Deserialized: {:?}", round_tripped);

        assert_eq!(round_tripped.content, "Test thinking content");
        assert_eq!(round_tripped.tokens, 42);
    }
}