use super::tool_registry::ToolRegistry;
use super::types::FunctionCall;
use serde_json::Value;
use std::sync::Arc;
/// Optional callbacks that can be attached to an LLM request.
///
/// Every slot starts out as `None` (the type derives [`Default`]). Each callback is
/// stored as a reference-counted `Arc<dyn Fn(..) + Send + Sync>`. That means a single
/// callback can be shared cheaply and invoked from any thread.
#[derive(Default)]
pub struct LLMEventHandlers {
    /// Receives a chunk of generated text.
    /// NOTE(review): presumably fired once per streamed token; confirm with the backend.
    pub on_token: Option<Arc<dyn Fn(&str) + Send + Sync>>,
    /// Argument-less notification. By its name, it marks the start of a "thinking" section.
    pub on_start_thinking: Option<Arc<dyn Fn() + Send + Sync>>,
    /// Argument-less notification. By its name, it marks the end of a "thinking" section.
    pub on_stop_thinking: Option<Arc<dyn Fn() + Send + Sync>>,
    /// Receives a chunk of "thinking" text (the backend decides what counts as thinking).
    pub on_thinking: Option<Arc<dyn Fn(&str) + Send + Sync>>,
    /// Receives the function call the model requested.
    pub on_tool_call: Option<Arc<dyn Fn(&FunctionCall) + Send + Sync>>,
    /// Receives a function call together with its outcome: `Ok(value)` or `Err(message)`.
    pub on_tool_result:
        Option<Arc<dyn Fn(&FunctionCall, &Result<Value, String>) + Send + Sync>>,
}
/// Per-request settings for an LLM call, normally assembled with the builder methods.
///
/// `LLMOptions::default()` leaves `streaming` as `false` and every `Option` field as
/// `None`, meaning "not set by the caller". How unset values are interpreted is up to
/// the code that consumes these options.
#[derive(Default)]
pub struct LLMOptions<'a> {
    /// Whether the response should be streamed (`false` by default).
    pub streaming: bool,
    /// Tool registry made available to the request, borrowed for `'a`.
    pub tools: Option<&'a ToolRegistry>,
    /// Sampling temperature.
    pub temperature: Option<f32>,
    /// Nucleus (top-p) sampling threshold.
    pub top_p: Option<f32>,
    /// Min-p sampling threshold.
    pub min_p: Option<f32>,
    /// Top-k sampling cutoff.
    pub top_k: Option<u32>,
    /// Upper bound on generated tokens.
    pub max_tokens: Option<u32>,
    /// Stop sequences.
    pub stop: Option<Vec<String>>,
    /// Presence penalty.
    pub presence_penalty: Option<f32>,
    /// Frequency penalty.
    pub frequency_penalty: Option<f32>,
    /// Repeat penalty.
    pub repeat_penalty: Option<f32>,
    /// Policy name for context overflow.
    /// NOTE(review): valid values are not checked here; confirm against the backend.
    pub context_overflow_policy: Option<String>,
    /// Callbacks invoked while the request runs.
    pub event_handlers: LLMEventHandlers,
}
impl<'a> LLMOptions<'a> {
    /// Creates options with nothing set; identical to [`Default::default`].
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the `streaming` flag.
    #[must_use]
    pub fn streaming(mut self, streaming: bool) -> Self {
        self.streaming = streaming;
        self
    }

    /// Attaches a tool registry, borrowed for the lifetime `'a` of these options.
    #[must_use]
    pub fn with_tools(mut self, tools: &'a ToolRegistry) -> Self {
        self.tools = Some(tools);
        self
    }

    /// Sets `temperature`.
    #[must_use]
    pub fn temperature(mut self, temperature: f32) -> Self {
        self.temperature = Some(temperature);
        self
    }

    /// Sets `top_p`.
    #[must_use]
    pub fn top_p(mut self, top_p: f32) -> Self {
        self.top_p = Some(top_p);
        self
    }

    /// Sets `min_p`. Leaves `top_p` untouched.
    #[must_use]
    pub fn min_p(mut self, min_p: f32) -> Self {
        self.min_p = Some(min_p);
        self
    }

    /// Sets `top_k`.
    #[must_use]
    pub fn top_k(mut self, top_k: u32) -> Self {
        self.top_k = Some(top_k);
        self
    }

    /// Sets `max_tokens`.
    #[must_use]
    pub fn max_tokens(mut self, max_tokens: u32) -> Self {
        self.max_tokens = Some(max_tokens);
        self
    }

    /// Sets the stop sequences, replacing any previously configured list.
    #[must_use]
    pub fn stop(mut self, stop: Vec<String>) -> Self {
        self.stop = Some(stop);
        self
    }

    /// Sets `presence_penalty`.
    #[must_use]
    pub fn presence_penalty(mut self, penalty: f32) -> Self {
        self.presence_penalty = Some(penalty);
        self
    }

    /// Sets `frequency_penalty`.
    #[must_use]
    pub fn frequency_penalty(mut self, penalty: f32) -> Self {
        self.frequency_penalty = Some(penalty);
        self
    }

    /// Sets `repeat_penalty`.
    #[must_use]
    pub fn repeat_penalty(mut self, repeat_penalty: f32) -> Self {
        self.repeat_penalty = Some(repeat_penalty);
        self
    }

    /// Sets `context_overflow_policy`, copying the name into an owned `String`.
    ///
    /// The value is not validated here.
    /// NOTE(review): presumably a backend-specific policy name; confirm.
    #[must_use]
    pub fn context_overflow_policy(mut self, context_overflow_policy: &str) -> Self {
        self.context_overflow_policy = Some(context_overflow_policy.to_string());
        self
    }

    /// Registers the `on_token` callback, replacing any previously registered one.
    #[must_use]
    pub fn on_token<F>(mut self, callback: F) -> Self
    where
        F: Fn(&str) + Send + Sync + 'static,
    {
        self.event_handlers.on_token = Some(Arc::new(callback));
        self
    }

    /// Registers the `on_start_thinking` callback, replacing any previously registered one.
    #[must_use]
    pub fn on_start_thinking<F>(mut self, callback: F) -> Self
    where
        F: Fn() + Send + Sync + 'static,
    {
        self.event_handlers.on_start_thinking = Some(Arc::new(callback));
        self
    }

    /// Registers the `on_stop_thinking` callback, replacing any previously registered one.
    #[must_use]
    pub fn on_stop_thinking<F>(mut self, callback: F) -> Self
    where
        F: Fn() + Send + Sync + 'static,
    {
        self.event_handlers.on_stop_thinking = Some(Arc::new(callback));
        self
    }

    /// Registers the `on_thinking` callback, replacing any previously registered one.
    #[must_use]
    pub fn on_thinking<F>(mut self, callback: F) -> Self
    where
        F: Fn(&str) + Send + Sync + 'static,
    {
        self.event_handlers.on_thinking = Some(Arc::new(callback));
        self
    }

    /// Registers the `on_tool_call` callback, replacing any previously registered one.
    #[must_use]
    pub fn on_tool_call<F>(mut self, callback: F) -> Self
    where
        F: Fn(&FunctionCall) + Send + Sync + 'static,
    {
        self.event_handlers.on_tool_call = Some(Arc::new(callback));
        self
    }

    /// Registers the `on_tool_result` callback, replacing any previously registered one.
    #[must_use]
    pub fn on_tool_result<F>(mut self, callback: F) -> Self
    where
        F: Fn(&FunctionCall, &Result<Value, String>) + Send + Sync + 'static,
    {
        self.event_handlers.on_tool_result = Some(Arc::new(callback));
        self
    }
}