use crate::types::Record;
use serde::{Deserialize, Serialize};
/// How the model may pick tools during a chat turn.
///
/// Serialized as an internally tagged object: the variant name in
/// `snake_case` under a `"type"` key, e.g. `{"type": "auto"}`,
/// `{"type": "required"}`, or `{"type": "tool", "name": "..."}`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ToolChoice {
    /// Serialized as `{"type": "auto"}` (presumably lets the model decide — confirm).
    Auto,
    /// Serialized as `{"type": "none"}`.
    /// NOTE(review): this variant shadows `Option::None` if the variants are
    /// ever glob-imported; always spell it `ToolChoice::None`.
    None,
    /// Serialized as `{"type": "required"}`.
    Required,
    /// Targets one tool by name; serialized as `{"type": "tool", "name": ...}`.
    Tool { name: String },
}
/// Tool-calling settings attachable to a session or a single message.
///
/// Every field carries `#[serde(default)]`. A missing field, or an empty JSON
/// object, therefore deserializes to the same value as `ToolConfig::default()`:
/// `enabled == false` and every `Option` set to `None`. `None` fields are
/// omitted when serializing.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct ToolConfig {
    /// Whether tool calling is turned on; `false` when absent.
    #[serde(default)]
    pub enabled: bool,
    /// Optional allow-list of tool names.
    /// NOTE(review): presumably `None` means "no restriction" — confirm server-side.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub allowed_tools: Option<Vec<String>>,
    /// Optional allow-list of collection names (presumably the collections
    /// tools may access — confirm).
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub allowed_collections: Option<Vec<String>>,
    /// Optional cap on tool-call iterations (scope, e.g. per turn, not visible here).
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub max_iterations: Option<u32>,
    /// Optional flag permitting write operations; enforcement happens server-side.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub allow_write_operations: Option<bool>,
    /// Optional tool-selection mode; see [`ToolChoice`] for the wire format.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub tool_choice: Option<ToolChoice>,
}
/// Model identifiers grouped by LLM provider.
///
/// All three lists are required when deserializing (no `#[serde(default)]`).
/// NOTE(review): presumably the payload of a "list available models" call — confirm.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Models {
    /// Model names offered for the `openai` provider.
    pub openai: Vec<String>,
    /// Model names offered for the `anthropic` provider.
    pub anthropic: Vec<String>,
    /// Model names offered for the `perplexity` provider.
    pub perplexity: Vec<String>,
}
/// A collection a chat is configured against, plus the fields to search in it.
///
/// NOTE(review): presumably used for context retrieval (cf. `ContextSnippet`) — confirm.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CollectionConfig {
    /// Name of the collection.
    pub collection_name: String,
    /// Fields to search, each with optional per-field search options.
    pub fields: Vec<FieldSearchOptions>,
    /// Collection-wide search options; omitted from JSON when `None`.
    /// NOTE(review): precedence versus per-field options is decided server-side — confirm.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub search_options: Option<TextSearchOptions>,
}
/// One searchable field inside a [`CollectionConfig`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FieldSearchOptions {
    /// Field name within the collection's records.
    pub field: String,
    /// Options for this field only; omitted from JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub search_options: Option<TextSearchOptions>,
}
/// Text-search tuning knobs.
///
/// Every field is optional and omitted from JSON when `None`, so
/// `TextSearchOptions::default()` serializes to `{}`. Defaults for unset knobs
/// are chosen by the server (not visible here).
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct TextSearchOptions {
    /// Language hint (format, e.g. ISO code or name, not specified here).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub language: Option<String>,
    /// Whether matching is case-sensitive.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub case_sensitive: Option<bool>,
    /// Whether fuzzy matching is enabled.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub fuzzy_match: Option<bool>,
    /// Minimum score for a match to be kept (scale defined server-side).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub min_score: Option<f64>,
    /// Whether stemming is applied.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub enable_stemming: Option<bool>,
    /// Whether exact matches get a score boost.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub boost_exact_matches: Option<bool>,
    /// Maximum edit distance (presumably only relevant with `fuzzy_match` — confirm).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_edit_distance: Option<u32>,
}
/// Request body for a single chat exchange; build it with [`ChatRequest::new`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatRequest {
    /// The user's message.
    pub message: String,
    /// Collections configured for this request (may be empty).
    pub collections: Vec<CollectionConfig>,
    /// LLM provider name, e.g. `"openai"`.
    pub llm_provider: String,
    /// Specific model; omitted from JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub llm_model: Option<String>,
    /// System prompt; omitted from JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub system_prompt: Option<String>,
}
/// Fluent construction helpers for [`ChatRequest`].
impl ChatRequest {
    /// Starts a request that carries `message` and targets `llm_provider`.
    ///
    /// The request has no collections, model, or system prompt. Chain the other
    /// methods to fill those in.
    pub fn new(message: impl Into<String>, llm_provider: impl Into<String>) -> Self {
        let message = message.into();
        let llm_provider = llm_provider.into();
        Self {
            message,
            llm_provider,
            collections: Vec::new(),
            llm_model: None,
            system_prompt: None,
        }
    }

    /// Appends `collection` after any collections already attached.
    pub fn collection(self, collection: CollectionConfig) -> Self {
        let mut request = self;
        request.collections.push(collection);
        request
    }

    /// Selects a specific model, replacing any earlier choice.
    pub fn model(self, model: impl Into<String>) -> Self {
        Self {
            llm_model: Some(model.into()),
            ..self
        }
    }

    /// Sets the system prompt, replacing any earlier one.
    pub fn system_prompt(self, prompt: impl Into<String>) -> Self {
        Self {
            system_prompt: Some(prompt.into()),
            ..self
        }
    }
}
/// Server reply to a chat request.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatResponse {
    /// Identifier of the chat the reply belongs to.
    pub chat_id: String,
    /// Identifier of the reply message.
    pub message_id: String,
    /// Response texts produced by the model (a list; count decided server-side).
    pub responses: Vec<String>,
    /// Retrieved records that were matched for this reply.
    pub context_snippets: Vec<ContextSnippet>,
    /// Server-reported execution time (the name indicates milliseconds).
    pub execution_time_ms: u64,
    /// Token accounting, when the server reports it; omitted from JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub token_usage: Option<TokenUsage>,
}
/// Token counts reported for a completion; any count may be missing.
///
/// Unlike most types in this module, `None` counts are serialized as `null`
/// (there is no `skip_serializing_if`). Missing keys still deserialize to `None`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TokenUsage {
    /// Tokens in the prompt.
    pub prompt_tokens: Option<u32>,
    /// Tokens in the completion.
    pub completion_tokens: Option<u32>,
    /// Total tokens as reported (not recomputed client-side).
    pub total_tokens: Option<u32>,
}
/// A record matched from a collection while answering a chat message.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ContextSnippet {
    /// Collection the record came from.
    pub collection: String,
    /// The record as arbitrary JSON.
    pub record: serde_json::Value,
    /// Match score (scale defined server-side).
    pub score: f64,
    /// Names of the fields that matched.
    pub matched_fields: Vec<String>,
}
/// Request body for creating a chat session; build it with
/// [`CreateChatSessionRequest::new`].
///
/// Every `Option` field is omitted from the JSON when `None`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CreateChatSessionRequest {
    /// Collections configured for the session (may be empty).
    pub collections: Vec<CollectionConfig>,
    /// LLM provider name, e.g. `"openai"`.
    pub llm_provider: String,
    /// Specific model for the session.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub llm_model: Option<String>,
    /// System prompt for the session.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub system_prompt: Option<String>,
    /// Optional agent identifier to associate with the session.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub agent_id: Option<String>,
    /// Optional `bypass_ripple` flag (semantics defined server-side).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub bypass_ripple: Option<bool>,
    /// Parent session when branching; set together with `branch_point_idx`
    /// by `branch_from`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub parent_id: Option<String>,
    /// Message index in the parent at which the branch starts
    /// (presumably 0-based — confirm).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub branch_point_idx: Option<usize>,
    /// Cap on how many messages are kept as context.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_context_messages: Option<usize>,
    /// Cap on generated tokens.
    /// NOTE(review): signed type; negative values are not rejected client-side.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<i32>,
    /// Sampling temperature.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,
    /// Tool-calling configuration for the session.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub tool_config: Option<ToolConfig>,
}
/// Builder-style helpers for [`CreateChatSessionRequest`].
///
/// Every optional field has a setter. Fields never set stay `None` and are
/// omitted from the serialized request.
impl CreateChatSessionRequest {
    /// Creates a request for `llm_provider` with no collections and every
    /// optional setting left as `None`.
    pub fn new(llm_provider: impl Into<String>) -> Self {
        Self {
            collections: Vec::new(),
            llm_provider: llm_provider.into(),
            llm_model: None,
            system_prompt: None,
            agent_id: None,
            bypass_ripple: None,
            parent_id: None,
            branch_point_idx: None,
            max_context_messages: None,
            max_tokens: None,
            temperature: None,
            tool_config: None,
        }
    }

    /// Appends `collection` after any collections already attached.
    pub fn collection(mut self, collection: CollectionConfig) -> Self {
        self.collections.push(collection);
        self
    }

    /// Sets the model, replacing any earlier choice.
    pub fn model(mut self, model: impl Into<String>) -> Self {
        self.llm_model = Some(model.into());
        self
    }

    /// Sets the system prompt, replacing any earlier one.
    pub fn system_prompt(mut self, prompt: impl Into<String>) -> Self {
        self.system_prompt = Some(prompt.into());
        self
    }

    /// Associates the session with agent `id`.
    pub fn agent_id(mut self, id: impl Into<String>) -> Self {
        self.agent_id = Some(id.into());
        self
    }

    /// Sets the `bypass_ripple` flag.
    ///
    /// This field previously had no setter and could only be assigned directly.
    pub fn bypass_ripple(mut self, bypass: bool) -> Self {
        self.bypass_ripple = Some(bypass);
        self
    }

    /// Marks the session as a branch of `parent_id` starting at
    /// `branch_point_idx`.
    ///
    /// Sets both fields together so they cannot be set inconsistently through
    /// this builder.
    pub fn branch_from(mut self, parent_id: impl Into<String>, branch_point_idx: usize) -> Self {
        self.parent_id = Some(parent_id.into());
        self.branch_point_idx = Some(branch_point_idx);
        self
    }

    /// Caps the number of context messages at `max`.
    pub fn max_context_messages(mut self, max: usize) -> Self {
        self.max_context_messages = Some(max);
        self
    }

    /// Caps generated tokens at `max`. The value is passed through unvalidated.
    pub fn max_tokens(mut self, max: i32) -> Self {
        self.max_tokens = Some(max);
        self
    }

    /// Sets the sampling temperature. The value is passed through unvalidated.
    pub fn temperature(mut self, temp: f32) -> Self {
        self.temperature = Some(temp);
        self
    }

    /// Sets the session's tool-calling configuration.
    pub fn tool_config(mut self, config: ToolConfig) -> Self {
        self.tool_config = Some(config);
        self
    }
}
/// Response wrapping a stored chat session record.
///
/// Both fields carry `#[serde(default)]`. A missing `session` deserializes to
/// `Record::default()` and a missing `message_count` to `0`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatSessionResponse {
    /// The session as a generic record (`crate::types::Record`).
    #[serde(default)]
    pub session: Record,
    /// Message count reported for the session; `0` when absent.
    #[serde(default)]
    pub message_count: usize,
}
/// Summary of a chat session, as listed in [`ListSessionsResponse`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatSession {
    /// Session identifier.
    pub chat_id: String,
    /// Creation timestamp as a string (presumably ISO 8601 — confirm).
    pub created_at: String,
    /// Last-update timestamp as a string (same format as `created_at`).
    pub updated_at: String,
    /// LLM provider name.
    pub llm_provider: String,
    /// Model name; required here even though it is optional on create requests.
    pub llm_model: String,
    /// Collections configured for the session.
    pub collections: Vec<CollectionConfig>,
    /// System prompt, if any.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub system_prompt: Option<String>,
    /// Associated agent identifier, if any.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub agent_id: Option<String>,
    /// Human-readable title, if any.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub title: Option<String>,
    /// Message count reported for the session.
    pub message_count: usize,
}
/// Inline file attachment sent with a [`ChatMessageRequest`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Attachment {
    /// MIME type of the payload, e.g. `"image/png"`.
    pub mime_type: String,
    /// Encoded payload.
    /// NOTE(review): presumably base64 — confirm the encoding the server expects.
    pub data: String,
}
/// Request body for sending a message in a chat session; build it with
/// [`ChatMessageRequest::new`].
///
/// Every `Option` field is omitted from the JSON when `None`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatMessageRequest {
    /// The user's message text.
    pub message: String,
    /// Optional `bypass_ripple` flag (semantics defined server-side).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub bypass_ripple: Option<bool>,
    /// Request that the server summarize (presumably the history — confirm).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub force_summarize: Option<bool>,
    /// Iteration cap for this message.
    /// NOTE(review): overlaps `ToolConfig::max_iterations`; precedence is server-defined.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_iterations: Option<u32>,
    /// Tool configuration for this message (presumably overrides the session's).
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub tool_config: Option<ToolConfig>,
    /// Model for this message (presumably overrides the session's).
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub llm_model: Option<String>,
    /// Tools defined by the client; see [`ClientToolDef`].
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub client_tools: Option<Vec<ClientToolDef>>,
    /// Names of tools that presumably require confirmation before running — confirm.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub confirm_tools: Option<Vec<String>>,
    /// Names of tools to exclude for this message.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub exclude_tools: Option<Vec<String>>,
    /// Inline attachments; see [`Attachment`].
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub attachments: Option<Vec<Attachment>>,
}
/// A tool declared by the client and offered to the model.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ClientToolDef {
    /// Tool name.
    pub name: String,
    /// Description of what the tool does.
    pub description: String,
    /// Parameter description as arbitrary JSON (presumably a JSON Schema — confirm).
    pub parameters: serde_json::Value,
}
/// Builder-style helpers for [`ChatMessageRequest`].
///
/// Every optional field has a setter. Fields never set stay `None` and are
/// omitted from the serialized request.
impl ChatMessageRequest {
    /// Creates a request carrying only `message`, with every optional field `None`.
    pub fn new(message: impl Into<String>) -> Self {
        Self {
            message: message.into(),
            bypass_ripple: None,
            force_summarize: None,
            max_iterations: None,
            tool_config: None,
            llm_model: None,
            client_tools: None,
            confirm_tools: None,
            exclude_tools: None,
            attachments: None,
        }
    }

    /// Sets the `bypass_ripple` flag for this message.
    ///
    /// This field previously had no setter and could only be assigned directly.
    pub fn bypass_ripple(mut self, bypass: bool) -> Self {
        self.bypass_ripple = Some(bypass);
        self
    }

    /// Replaces the attachment list with `attachments`.
    pub fn attachments(mut self, attachments: Vec<Attachment>) -> Self {
        self.attachments = Some(attachments);
        self
    }

    /// Sets the `force_summarize` flag.
    pub fn force_summarize(mut self, force: bool) -> Self {
        self.force_summarize = Some(force);
        self
    }

    /// Sets the per-message iteration cap.
    pub fn max_iterations(mut self, max: u32) -> Self {
        self.max_iterations = Some(max);
        self
    }

    /// Sets the per-message tool configuration.
    pub fn tool_config(mut self, config: ToolConfig) -> Self {
        self.tool_config = Some(config);
        self
    }

    /// Sets the model to use for this message.
    pub fn llm_model(mut self, model: impl Into<String>) -> Self {
        self.llm_model = Some(model.into());
        self
    }

    /// Replaces the client-defined tool list with `tools`.
    pub fn client_tools(mut self, tools: Vec<ClientToolDef>) -> Self {
        self.client_tools = Some(tools);
        self
    }

    /// Replaces the list of tool names that need confirmation with `tools`.
    pub fn confirm_tools(mut self, tools: Vec<String>) -> Self {
        self.confirm_tools = Some(tools);
        self
    }

    /// Replaces the list of excluded tool names with `tools`.
    pub fn exclude_tools(mut self, tools: Vec<String>) -> Self {
        self.exclude_tools = Some(tools);
        self
    }
}
/// Strategy for merging chat sessions.
///
/// There is no `rename_all` attribute, so variants serialize as their exact
/// Rust names (`"Chronological"`, `"Summarized"`, `"LatestOnly"`,
/// `"Interleaved"`).
/// NOTE(review): this casing differs from `ToolChoice` (`snake_case`); confirm
/// the server expects PascalCase before changing either.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum MergeStrategy {
    /// Serialized as `"Chronological"`.
    Chronological,
    /// Serialized as `"Summarized"`.
    Summarized,
    /// Serialized as `"LatestOnly"`.
    LatestOnly,
    /// Serialized as `"Interleaved"`.
    Interleaved,
}
/// Request body for merging one or more source sessions into a target session.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MergeSessionsRequest {
    /// Sessions to merge from.
    pub source_chat_ids: Vec<String>,
    /// Session to merge into.
    pub target_chat_id: String,
    /// How messages are combined; see [`MergeStrategy`].
    pub merge_strategy: MergeStrategy,
    /// Optional `bypass_ripple` flag; omitted from JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub bypass_ripple: Option<bool>,
}
/// One page of messages from a chat session.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GetMessagesResponse {
    /// The messages in this page, as generic records.
    pub messages: Vec<Record>,
    /// Total number of messages available.
    pub total: usize,
    /// Number of messages skipped before this page.
    pub skip: usize,
    /// Page limit that was applied. `None` is serialized as `null` here.
    pub limit: Option<usize>,
    /// Number of messages actually returned in `messages`.
    pub returned: usize,
}
/// Pagination and sorting parameters for fetching session messages.
///
/// Unset fields are omitted when serializing, so `GetMessagesQuery::default()`
/// serializes to an empty set of parameters.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct GetMessagesQuery {
    /// Maximum number of messages to return.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub limit: Option<usize>,
    /// Number of messages to skip.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub skip: Option<usize>,
    /// Sort order as a string (the tests use `"desc"`; accepted values are server-defined).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub sort: Option<String>,
}
/// Chainable setters for [`GetMessagesQuery`].
impl GetMessagesQuery {
    /// Returns a query with no limit, skip, or sort. This is the same value as
    /// `GetMessagesQuery::default()`.
    pub fn new() -> Self {
        Self {
            limit: None,
            skip: None,
            sort: None,
        }
    }

    /// Caps the page size at `limit`.
    pub fn limit(self, limit: usize) -> Self {
        Self {
            limit: Some(limit),
            ..self
        }
    }

    /// Skips the first `skip` messages.
    pub fn skip(self, skip: usize) -> Self {
        Self {
            skip: Some(skip),
            ..self
        }
    }

    /// Sets the sort order string.
    pub fn sort(self, sort: impl Into<String>) -> Self {
        Self {
            sort: Some(sort.into()),
            ..self
        }
    }
}
/// Pagination and sorting parameters for listing chat sessions.
///
/// Unset fields are omitted when serializing.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct ListSessionsQuery {
    /// Maximum number of sessions to return.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub limit: Option<usize>,
    /// Number of sessions to skip.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub skip: Option<usize>,
    /// Sort order as a string (the tests use `"asc"`; accepted values are server-defined).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub sort: Option<String>,
}
/// Chainable setters for [`ListSessionsQuery`].
impl ListSessionsQuery {
    /// Returns a query with no limit, skip, or sort. This is the same value as
    /// `ListSessionsQuery::default()`.
    pub fn new() -> Self {
        Self {
            limit: None,
            skip: None,
            sort: None,
        }
    }

    /// Caps the number of sessions returned at `limit`.
    pub fn limit(self, limit: usize) -> Self {
        Self {
            limit: Some(limit),
            ..self
        }
    }

    /// Skips the first `skip` sessions.
    pub fn skip(self, skip: usize) -> Self {
        Self {
            skip: Some(skip),
            ..self
        }
    }

    /// Sets the sort order string.
    pub fn sort(self, sort: impl Into<String>) -> Self {
        Self {
            sort: Some(sort.into()),
            ..self
        }
    }
}
/// One page of chat-session summaries.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ListSessionsResponse {
    /// The sessions in this page.
    pub sessions: Vec<ChatSession>,
    /// Total number of sessions available.
    pub total: usize,
    /// Number of sessions actually returned in `sessions`.
    pub returned: usize,
}
/// Partial update for an existing chat session.
///
/// Only fields set to `Some` are serialized. `UpdateSessionRequest::default()`
/// therefore serializes to `{}` (presumably a no-op update — confirm).
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct UpdateSessionRequest {
    /// New system prompt.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub system_prompt: Option<String>,
    /// New model name.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub llm_model: Option<String>,
    /// Replacement collection list (presumably replaces rather than appends — confirm).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub collections: Option<Vec<CollectionConfig>>,
    /// New cap on context messages.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_context_messages: Option<usize>,
    /// New `bypass_ripple` flag.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub bypass_ripple: Option<bool>,
    /// New session title.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub title: Option<String>,
    /// Session memory as arbitrary JSON (schema defined server-side).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub memory: Option<serde_json::Value>,
}
/// Builder-style helpers for [`UpdateSessionRequest`].
///
/// Every field has a setter. Fields never set stay `None` and are omitted
/// from the serialized update.
impl UpdateSessionRequest {
    /// Returns an update with no fields set; same as `UpdateSessionRequest::default()`.
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets a new system prompt.
    pub fn system_prompt(mut self, prompt: impl Into<String>) -> Self {
        self.system_prompt = Some(prompt.into());
        self
    }

    /// Sets a new model name (stored in `llm_model`).
    pub fn model(mut self, model: impl Into<String>) -> Self {
        self.llm_model = Some(model.into());
        self
    }

    /// Sets a new session title.
    pub fn title(mut self, title: impl Into<String>) -> Self {
        self.title = Some(title.into());
        self
    }

    /// Sets the collection list sent with the update.
    pub fn collections(mut self, collections: Vec<CollectionConfig>) -> Self {
        self.collections = Some(collections);
        self
    }

    /// Sets a new cap on context messages.
    pub fn max_context_messages(mut self, max: usize) -> Self {
        self.max_context_messages = Some(max);
        self
    }

    /// Sets the `bypass_ripple` flag.
    ///
    /// This field previously had no setter and could only be assigned directly.
    pub fn bypass_ripple(mut self, bypass: bool) -> Self {
        self.bypass_ripple = Some(bypass);
        self
    }

    /// Sets the session memory to the given JSON value.
    pub fn memory(mut self, memory: serde_json::Value) -> Self {
        self.memory = Some(memory);
        self
    }
}
/// Request body for replacing a message's content.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct UpdateMessageRequest {
    /// The new message content.
    pub content: String,
}
/// Request body for setting a message's `forgotten` flag.
///
/// NOTE(review): presumably forgotten messages are excluded from model
/// context — confirm server behavior.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToggleForgottenRequest {
    /// Desired value of the flag.
    pub forgotten: bool,
}
/// Request body for a one-off completion with an explicit system prompt.
///
/// NOTE(review): presumably this bypasses chat-session state — confirm with the server.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RawCompletionRequest {
    /// System prompt sent with the completion.
    pub system_prompt: String,
    /// User message to complete.
    pub message: String,
    /// LLM provider; omitted from JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub provider: Option<String>,
    /// Model name; omitted from JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub model: Option<String>,
    /// Cap on generated tokens; omitted from JSON when `None`.
    /// NOTE(review): signed type; negative values are not rejected client-side.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<i32>,
}

/// Response body for a raw completion.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RawCompletionResponse {
    /// The generated text.
    pub content: String,
}

/// Request body for computing embeddings.
///
/// NOTE(review): presumably exactly one of `text` / `texts` should be set. This
/// type does not enforce that; confirm how the server treats both or neither.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmbedRequest {
    /// A single input text; omitted from JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub text: Option<String>,
    /// A batch of input texts; omitted from JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub texts: Option<Vec<String>>,
    /// Embedding model name; omitted from JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub model: Option<String>,
}

/// Embedding vectors returned for an [`EmbedRequest`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmbedResponse {
    /// One vector per input (presumably in input order — confirm).
    pub embeddings: Vec<Vec<f64>>,
    /// Model that produced the embeddings.
    pub model: String,
    /// Reported vector length (presumably equals each inner `Vec`'s length).
    pub dimensions: usize,
}

// The test module is kept last: items declared after a `#[cfg(test)]` module
// trigger clippy's `items_after_test_module` lint.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_chat_request_builder() {
        let request = ChatRequest::new("Hello", "openai")
            .model("gpt-4")
            .system_prompt("You are a helpful assistant");
        assert_eq!(request.message, "Hello");
        assert_eq!(request.llm_provider, "openai");
        assert_eq!(request.llm_model, Some("gpt-4".to_string()));
        assert!(request.system_prompt.is_some());
    }

    #[test]
    fn test_create_session_request_builder() {
        let request = CreateChatSessionRequest::new("openai")
            .model("gpt-4")
            .system_prompt("Test prompt");
        assert_eq!(request.llm_provider, "openai");
        assert_eq!(request.llm_model, Some("gpt-4".to_string()));
        assert!(request.system_prompt.is_some());
    }

    #[test]
    fn test_chat_message_request() {
        let request = ChatMessageRequest::new("Hello").force_summarize(true);
        assert_eq!(request.message, "Hello");
        assert_eq!(request.force_summarize, Some(true));
    }

    #[test]
    fn test_get_messages_query() {
        let query = GetMessagesQuery::new().limit(10).skip(5).sort("desc");
        assert_eq!(query.limit, Some(10));
        assert_eq!(query.skip, Some(5));
        assert_eq!(query.sort, Some("desc".to_string()));
    }

    #[test]
    fn test_list_sessions_query() {
        let query = ListSessionsQuery::new().limit(20).sort("asc");
        assert_eq!(query.limit, Some(20));
        assert_eq!(query.sort, Some("asc".to_string()));
    }

    #[test]
    fn test_update_session_request() {
        let request = UpdateSessionRequest::new()
            .title("Updated Title")
            .model("gpt-4-turbo");
        assert_eq!(request.title, Some("Updated Title".to_string()));
        assert_eq!(request.llm_model, Some("gpt-4-turbo".to_string()));
    }

    // Pins the internally tagged, snake_case wire format of `ToolChoice`.
    #[test]
    fn test_tool_choice_wire_format() {
        let tool = serde_json::to_value(ToolChoice::Tool {
            name: "search".to_string(),
        })
        .unwrap();
        assert_eq!(tool, serde_json::json!({"type": "tool", "name": "search"}));
        let auto = serde_json::to_value(ToolChoice::Auto).unwrap();
        assert_eq!(auto, serde_json::json!({"type": "auto"}));
    }

    // Every `ToolConfig` field has `#[serde(default)]`, so `{}` must deserialize.
    #[test]
    fn test_tool_config_from_empty_object() {
        let config: ToolConfig = serde_json::from_str("{}").unwrap();
        assert!(!config.enabled);
        assert!(config.allowed_tools.is_none());
        assert!(config.tool_choice.is_none());
    }

    // Unset query parameters must be omitted, not sent as `null`.
    #[test]
    fn test_empty_query_serializes_to_empty_object() {
        let value = serde_json::to_value(GetMessagesQuery::new()).unwrap();
        assert_eq!(value, serde_json::json!({}));
    }
}