use crate::config::Config;
use crate::providers::{ChatCompletionParams, ProviderFactory, ProviderResponse};
use crate::session::token_counter::{estimate_full_context_tokens, estimate_session_tokens};
use crate::session::Message;
use anyhow::Result;
use tokio::sync::watch;
/// Inputs for [`chat_completion_with_validation`].
///
/// Build with [`ChatCompletionWithValidationParams::new`] and the `with_*`
/// builder methods, or construct the struct directly.
pub struct ChatCompletionWithValidationParams<'a> {
    // --- request content ---
    /// Conversation messages sent to the model.
    pub messages: &'a [Message],
    /// Model identifier, resolved to a provider and concrete model name by
    /// `ProviderFactory::get_provider_for_model`.
    pub model: &'a str,
    /// Optional JSON schema for structured output.
    pub schema: Option<serde_json::Value>,

    // --- sampling ---
    pub temperature: f32,
    pub top_p: f32,
    pub top_k: u32,
    pub max_tokens: u32,

    // --- execution ---
    /// Retry count forwarded to the provider request parameters.
    pub max_retries: u32,
    /// Application configuration passed through to the request.
    pub config: &'a Config,
    /// When present, `chat_completion_with_validation` also counts the
    /// available MCP tool definitions in its input-size estimate.
    pub chat_session: Option<&'a mut crate::session::chat::session::ChatSession>,
    /// Optional cancellation signal; a current value of `true` aborts the request.
    pub cancellation_token: Option<watch::Receiver<bool>>,
}
impl<'a> ChatCompletionWithValidationParams<'a> {
    /// Creates parameters from the required request and sampling settings.
    ///
    /// Every optional setting starts unset: zero retries, no chat session,
    /// no cancellation token, and no JSON schema. Use the `with_*` methods to
    /// fill them in.
    pub fn new(
        messages: &'a [Message],
        model: &'a str,
        temperature: f32,
        top_p: f32,
        top_k: u32,
        max_tokens: u32,
        config: &'a Config,
    ) -> Self {
        Self {
            config,
            messages,
            model,
            max_tokens,
            temperature,
            top_k,
            top_p,
            max_retries: 0,
            schema: None,
            cancellation_token: None,
            chat_session: None,
        }
    }

    /// Replaces the retry count forwarded to the provider request.
    pub fn with_max_retries(self, max_retries: u32) -> Self {
        Self { max_retries, ..self }
    }

    /// Attaches a chat session (used to include tool definitions in the
    /// input-size estimate).
    pub fn with_chat_session(
        self,
        chat_session: &'a mut crate::session::chat::session::ChatSession,
    ) -> Self {
        Self {
            chat_session: Some(chat_session),
            ..self
        }
    }

    /// Attaches a cancellation signal; a current value of `true` aborts the request.
    pub fn with_cancellation_token(self, token: watch::Receiver<bool>) -> Self {
        Self {
            cancellation_token: Some(token),
            ..self
        }
    }

    /// Requests structured output conforming to `schema`.
    pub fn with_schema(self, schema: serde_json::Value) -> Self {
        Self {
            schema: Some(schema),
            ..self
        }
    }
}
/// Inputs for [`chat_completion_with_provider`], which skips input-size validation.
pub struct ChatCompletionProviderParams<'a> {
    // --- request content ---
    /// Conversation messages sent to the model.
    pub messages: &'a [Message],
    /// Model identifier, resolved to a provider and concrete model name by
    /// `ProviderFactory::get_provider_for_model`.
    pub model: &'a str,
    /// Optional JSON schema for structured output.
    pub schema: Option<serde_json::Value>,

    // --- sampling ---
    pub temperature: f32,
    pub top_p: f32,
    pub top_k: u32,
    pub max_tokens: u32,

    // --- execution ---
    /// Application configuration passed through to the request.
    pub config: &'a Config,
    /// Retry count forwarded to the provider request parameters.
    pub max_retries: u32,
    /// Optional cancellation signal; a value of `true` indicates the request
    /// should be cancelled.
    pub cancellation_token: Option<watch::Receiver<bool>>,
}
pub async fn chat_completion_with_validation(
params: ChatCompletionWithValidationParams<'_>,
) -> Result<ProviderResponse> {
if let Some(ref token) = params.cancellation_token {
if *token.borrow() {
return Err(anyhow::anyhow!("Request cancelled before validation"));
}
}
let (provider, actual_model) = ProviderFactory::get_provider_for_model(params.model)?;
let max_input_tokens = provider.get_max_input_tokens(&actual_model);
let total_input_tokens = if params.chat_session.is_some() {
let tools = crate::mcp::get_available_functions(params.config).await;
estimate_full_context_tokens(
params.messages,
if tools.is_empty() { None } else { Some(&tools) },
)
} else {
estimate_session_tokens(params.messages)
};
if total_input_tokens > max_input_tokens {
return Err(anyhow::anyhow!(
"Input size ({} tokens) exceeds provider limit ({} tokens) for {} {}",
total_input_tokens,
max_input_tokens,
provider.name(),
actual_model
));
}
if let Some(ref token) = params.cancellation_token {
if *token.borrow() {
return Err(anyhow::anyhow!("Request cancelled before API call"));
}
}
let chat_params = ChatCompletionParams::new(
params.messages,
&actual_model,
params.temperature,
params.top_p,
params.top_k,
params.max_tokens,
params.config,
)
.with_max_retries(params.max_retries);
let chat_params = if let Some(schema) = params.schema {
chat_params.with_schema(schema)
} else {
chat_params
};
let chat_params = if let Some(token) = params.cancellation_token {
chat_params.with_cancellation_token(token)
} else {
chat_params
};
let octolib_params = chat_params
.to_octolib_params()
.await
.map_err(|e| anyhow::anyhow!("Failed to convert message parameters: {}", e))?;
let octolib_response = provider.chat_completion(octolib_params).await?;
Ok(crate::providers::convert_response_from_octolib(
octolib_response,
))
}
pub async fn chat_completion_with_provider(
params: ChatCompletionProviderParams<'_>,
) -> Result<ProviderResponse> {
let (provider, actual_model) = ProviderFactory::get_provider_for_model(params.model)?;
if params.schema.is_some() && !provider.supports_structured_output(&actual_model) {
return Err(anyhow::anyhow!(
"Provider '{}' does not support structured output for model '{}'. Remove --schema or use a compatible provider.",
provider.name(),
actual_model
));
}
let chat_params = ChatCompletionParams::new(
params.messages,
&actual_model,
params.temperature,
params.top_p,
params.top_k,
params.max_tokens,
params.config,
)
.with_max_retries(params.max_retries);
let chat_params = if let Some(schema) = params.schema {
chat_params.with_schema(schema)
} else {
chat_params
};
let octolib_params = chat_params
.to_octolib_params()
.await
.map_err(|e| anyhow::anyhow!("Failed to convert message parameters: {}", e))?;
let octolib_response = provider.chat_completion(octolib_params).await?;
Ok(crate::providers::convert_response_from_octolib(
octolib_response,
))
}