1use super::provider::{LLMConfig, LLMProvider};
6use super::tool_executor::ToolExecutor;
7use super::types::*;
8use std::sync::Arc;
9
/// High-level entry point for talking to an LLM provider.
///
/// Pairs a shared [`LLMProvider`] with an [`LLMConfig`] whose defaults (model, temperature,
/// max tokens) are applied to requests created through [`LLMClient::chat`].
pub struct LLMClient {
    /// Backend that executes chat and embedding requests.
    provider: Arc<dyn LLMProvider>,
    /// Defaults applied to requests built by this client.
    config: LLMConfig,
}
36
37impl LLMClient {
38 pub fn new(provider: Arc<dyn LLMProvider>) -> Self {
40 Self {
41 provider,
42 config: LLMConfig::default(),
43 }
44 }
45
46 pub fn with_config(provider: Arc<dyn LLMProvider>, config: LLMConfig) -> Self {
48 Self { provider, config }
49 }
50
51 pub fn provider(&self) -> &Arc<dyn LLMProvider> {
53 &self.provider
54 }
55
56 pub fn config(&self) -> &LLMConfig {
58 &self.config
59 }
60
61 pub fn chat(&self) -> ChatRequestBuilder {
63 let model = self
64 .config
65 .default_model
66 .clone()
67 .unwrap_or_else(|| self.provider.default_model().to_string());
68
69 let mut builder = ChatRequestBuilder::new(self.provider.clone(), model);
70
71 if let Some(temp) = self.config.default_temperature {
72 builder = builder.temperature(temp);
73 }
74 if let Some(tokens) = self.config.default_max_tokens {
75 builder = builder.max_tokens(tokens);
76 }
77
78 builder
79 }
80
81 pub async fn embed(&self, input: impl Into<String>) -> LLMResult<Vec<f32>> {
83 let model = self
84 .config
85 .default_model
86 .clone()
87 .unwrap_or_else(|| "text-embedding-ada-002".to_string());
88
89 let request = EmbeddingRequest {
90 model,
91 input: EmbeddingInput::Single(input.into()),
92 encoding_format: None,
93 dimensions: None,
94 user: None,
95 };
96
97 let response = self.provider.embedding(request).await?;
98 response
99 .data
100 .into_iter()
101 .next()
102 .map(|d| d.embedding)
103 .ok_or_else(|| LLMError::Other("No embedding data returned".to_string()))
104 }
105
106 pub async fn embed_batch(&self, inputs: Vec<String>) -> LLMResult<Vec<Vec<f32>>> {
108 let model = self
109 .config
110 .default_model
111 .clone()
112 .unwrap_or_else(|| "text-embedding-ada-002".to_string());
113
114 let request = EmbeddingRequest {
115 model,
116 input: EmbeddingInput::Multiple(inputs),
117 encoding_format: None,
118 dimensions: None,
119 user: None,
120 };
121
122 let response = self.provider.embedding(request).await?;
123 Ok(response.data.into_iter().map(|d| d.embedding).collect())
124 }
125
126 pub async fn ask(&self, question: impl Into<String>) -> LLMResult<String> {
128 let response = self.chat().user(question).send().await?;
129
130 response
131 .content()
132 .map(|s| s.to_string())
133 .ok_or_else(|| LLMError::Other("No content in response".to_string()))
134 }
135
136 pub async fn ask_with_system(
138 &self,
139 system: impl Into<String>,
140 question: impl Into<String>,
141 ) -> LLMResult<String> {
142 let response = self.chat().system(system).user(question).send().await?;
143
144 response
145 .content()
146 .map(|s| s.to_string())
147 .ok_or_else(|| LLMError::Other("No content in response".to_string()))
148 }
149}
150
/// Fluent builder for a single chat-completion request.
///
/// Accumulates messages and request options, then dispatches via `send`, `send_stream`
/// or `send_with_tools`.
pub struct ChatRequestBuilder {
    /// Provider the request is sent to.
    provider: Arc<dyn LLMProvider>,
    /// The request being assembled.
    request: ChatCompletionRequest,
    /// Executor used by `send_with_tools` to run tool calls requested by the model.
    tool_executor: Option<Arc<dyn ToolExecutor>>,
    /// Limit on tool-call rounds in `send_with_tools` (10 when built via `new`).
    max_tool_rounds: u32,
    /// Retry policy used by `send` when retries are enabled.
    retry_policy: Option<LLMRetryPolicy>,
    /// Whether `send` routes the request through a retry executor.
    retry_enabled: bool,
}
161
162impl ChatRequestBuilder {
163 pub fn new(provider: Arc<dyn LLMProvider>, model: impl Into<String>) -> Self {
165 Self {
166 provider,
167 request: ChatCompletionRequest::new(model),
168 tool_executor: None,
169 max_tool_rounds: 10,
170 retry_policy: None,
171 retry_enabled: false,
172 }
173 }
174
175 pub fn system(mut self, content: impl Into<String>) -> Self {
177 self.request.messages.push(ChatMessage::system(content));
178 self
179 }
180
181 pub fn user(mut self, content: impl Into<String>) -> Self {
183 self.request.messages.push(ChatMessage::user(content));
184 self
185 }
186
187 pub fn user_with_content(mut self, content: MessageContent) -> Self {
189 self.request
190 .messages
191 .push(ChatMessage::user_with_content(content));
192 self
193 }
194
195 pub fn user_with_parts(mut self, parts: Vec<ContentPart>) -> Self {
197 self.request
198 .messages
199 .push(ChatMessage::user_with_parts(parts));
200 self
201 }
202
203 pub fn assistant(mut self, content: impl Into<String>) -> Self {
205 self.request.messages.push(ChatMessage::assistant(content));
206 self
207 }
208
209 pub fn message(mut self, message: ChatMessage) -> Self {
211 self.request.messages.push(message);
212 self
213 }
214
215 pub fn messages(mut self, messages: Vec<ChatMessage>) -> Self {
217 self.request.messages.extend(messages);
218 self
219 }
220
221 pub fn temperature(mut self, temp: f32) -> Self {
223 self.request.temperature = Some(temp);
224 self
225 }
226
227 pub fn max_tokens(mut self, tokens: u32) -> Self {
229 self.request.max_tokens = Some(tokens);
230 self
231 }
232
233 pub fn tool(mut self, tool: Tool) -> Self {
235 self.request.tools.get_or_insert_with(Vec::new).push(tool);
236 self
237 }
238
239 pub fn tools(mut self, tools: Vec<Tool>) -> Self {
241 self.request.tools = Some(tools);
242 self
243 }
244
245 pub fn with_tool_executor(mut self, executor: Arc<dyn ToolExecutor>) -> Self {
247 self.tool_executor = Some(executor);
248 self
249 }
250
251 pub fn max_tool_rounds(mut self, rounds: u32) -> Self {
253 self.max_tool_rounds = rounds;
254 self
255 }
256
257 pub fn json_mode(mut self) -> Self {
259 self.request.response_format = Some(ResponseFormat::json());
260 self
261 }
262
263 pub fn stop(mut self, sequences: Vec<String>) -> Self {
265 self.request.stop = Some(sequences);
266 self
267 }
268
269 pub fn with_retry(mut self) -> Self {
287 self.retry_enabled = true;
288 self.retry_policy = Some(LLMRetryPolicy::default());
289 self
290 }
291
292 pub fn with_retry_policy(mut self, policy: LLMRetryPolicy) -> Self {
314 self.retry_enabled = true;
315 self.retry_policy = Some(policy);
316 self
317 }
318
319 pub fn without_retry(mut self) -> Self {
324 self.retry_enabled = false;
325 self.retry_policy = None;
326 self
327 }
328
329 pub fn max_retries(mut self, max: u32) -> Self {
343 if self.retry_policy.is_none() {
344 self.retry_policy = Some(LLMRetryPolicy::default());
345 }
346 if let Some(ref mut policy) = self.retry_policy {
347 policy.max_attempts = max;
348 }
349 self.retry_enabled = true;
350 self
351 }
352
353 pub async fn send(self) -> LLMResult<ChatCompletionResponse> {
355 if self.retry_enabled {
356 let policy = self.retry_policy.unwrap_or_default();
357 let executor = crate::llm::retry::RetryExecutor::new(self.provider, policy);
358 executor.chat(self.request).await
359 } else {
360 self.provider.chat(self.request).await
361 }
362 }
363
364 pub async fn send_stream(mut self) -> LLMResult<super::provider::ChatStream> {
366 self.request.stream = Some(true);
367 self.provider.chat_stream(self.request).await
368 }
369
370 pub async fn send_with_tools(mut self) -> LLMResult<ChatCompletionResponse> {
375 let executor = self
376 .tool_executor
377 .take()
378 .ok_or_else(|| LLMError::ConfigError("Tool executor not set".to_string()))?;
379
380 if self
381 .request
382 .tools
383 .as_ref()
384 .map(|tools| tools.is_empty())
385 .unwrap_or(true)
386 {
387 let tools = executor.available_tools().await?;
388 if !tools.is_empty() {
389 self.request.tools = Some(tools);
390 }
391 }
392
393 let max_rounds = self.max_tool_rounds;
394 let mut round = 0;
395
396 loop {
397 let response = self.provider.chat(self.request.clone()).await?;
398
399 if !response.has_tool_calls() {
401 return Ok(response);
402 }
403
404 round += 1;
405 if round >= max_rounds {
406 return Err(LLMError::Other(format!(
407 "Max tool rounds ({}) exceeded",
408 max_rounds
409 )));
410 }
411
412 if let Some(choice) = response.choices.first() {
414 self.request.messages.push(choice.message.clone());
415 }
416
417 if let Some(tool_calls) = response.tool_calls() {
419 for tool_call in tool_calls {
420 let result = executor
421 .execute(&tool_call.function.name, &tool_call.function.arguments)
422 .await;
423
424 let result_str = match result {
425 Ok(r) => r,
426 Err(e) => format!("Error: {}", e),
427 };
428
429 self.request
431 .messages
432 .push(ChatMessage::tool_result(&tool_call.id, result_str));
433 }
434 }
435 }
436 }
437}
438
/// A stateful, multi-turn conversation backed by an [`LLMClient`] and persistence stores.
pub struct ChatSession {
    /// Identifier of this session; used as the session id when persisting.
    session_id: uuid::Uuid,
    /// User id written with each persisted message.
    user_id: uuid::Uuid,
    /// Agent id written with each persisted message and session record.
    agent_id: uuid::Uuid,
    /// Tenant id written with each persisted message.
    tenant_id: uuid::Uuid,
    /// Client used to build each chat request.
    client: LLMClient,
    /// In-memory conversation history sent (windowed) with each request.
    messages: Vec<ChatMessage>,
    /// System prompt prepended to every request; kept separately from `messages`.
    system_prompt: Option<String>,
    /// Tools offered to the model; when empty, the executor's available tools are queried.
    tools: Vec<Tool>,
    /// Executor for tool calls; when set, requests go through `send_with_tools`.
    tool_executor: Option<Arc<dyn ToolExecutor>>,
    /// Monotonic timestamp taken when the session value was constructed.
    created_at: std::time::Instant,
    /// Free-form key/value annotations held in memory.
    metadata: std::collections::HashMap<String, String>,
    /// Store that persisted messages are written to and loaded from.
    message_store: Arc<dyn crate::persistence::MessageStore>,
    /// Store that holds the session record.
    session_store: Arc<dyn crate::persistence::SessionStore>,
    /// Maximum number of conversation rounds (user + assistant message pairs) to keep;
    /// `None` or `Some(0)` means unlimited.
    context_window_size: Option<usize>,
    /// Metadata extracted from the most recent successful response.
    last_response_metadata: Option<super::types::LLMResponseMetadata>,
}
478
479impl ChatSession {
480 pub fn new(client: LLMClient) -> Self {
482 let store = Arc::new(crate::persistence::InMemoryStore::new());
484 Self::with_id_and_stores(
485 Self::generate_session_id(),
486 client,
487 uuid::Uuid::now_v7(),
488 uuid::Uuid::now_v7(),
489 uuid::Uuid::now_v7(),
490 store.clone(),
491 store.clone(),
492 None,
493 )
494 }
495
496 pub fn new_with_stores(
498 client: LLMClient,
499 user_id: uuid::Uuid,
500 tenant_id: uuid::Uuid,
501 agent_id: uuid::Uuid,
502 message_store: Arc<dyn crate::persistence::MessageStore>,
503 session_store: Arc<dyn crate::persistence::SessionStore>,
504 ) -> Self {
505 Self::with_id_and_stores(
506 Self::generate_session_id(),
507 client,
508 user_id,
509 tenant_id,
510 agent_id,
511 message_store,
512 session_store,
513 None,
514 )
515 }
516
517 pub fn with_id(session_id: uuid::Uuid, client: LLMClient) -> Self {
519 let store = Arc::new(crate::persistence::InMemoryStore::new());
521 Self {
522 session_id,
523 user_id: uuid::Uuid::now_v7(),
524 agent_id: uuid::Uuid::now_v7(),
525 tenant_id: uuid::Uuid::now_v7(),
526 client,
527 messages: Vec::new(),
528 system_prompt: None,
529 tools: Vec::new(),
530 tool_executor: None,
531 created_at: std::time::Instant::now(),
532 metadata: std::collections::HashMap::new(),
533 message_store: store.clone(),
534 session_store: store.clone(),
535 context_window_size: None,
536 last_response_metadata: None,
537 }
538 }
539
540 pub fn with_id_str(session_id: &str, client: LLMClient) -> Self {
542 let session_id = uuid::Uuid::parse_str(session_id).unwrap_or_else(|_| uuid::Uuid::now_v7());
544 Self::with_id(session_id, client)
545 }
546
547 pub fn with_id_and_stores(
549 session_id: uuid::Uuid,
550 client: LLMClient,
551 user_id: uuid::Uuid,
552 tenant_id: uuid::Uuid,
553 agent_id: uuid::Uuid,
554 message_store: Arc<dyn crate::persistence::MessageStore>,
555 session_store: Arc<dyn crate::persistence::SessionStore>,
556 context_window_size: Option<usize>,
557 ) -> Self {
558 Self {
559 session_id,
560 user_id,
561 tenant_id,
562 agent_id,
563 client,
564 messages: Vec::new(),
565 system_prompt: None,
566 tools: Vec::new(),
567 tool_executor: None,
568 created_at: std::time::Instant::now(),
569 metadata: std::collections::HashMap::new(),
570 message_store,
571 session_store,
572 context_window_size,
573 last_response_metadata: None,
574 }
575 }
576
577 pub async fn with_id_and_stores_and_persist(
597 session_id: uuid::Uuid,
598 client: LLMClient,
599 user_id: uuid::Uuid,
600 tenant_id: uuid::Uuid,
601 agent_id: uuid::Uuid,
602 message_store: Arc<dyn crate::persistence::MessageStore>,
603 session_store: Arc<dyn crate::persistence::SessionStore>,
604 context_window_size: Option<usize>,
605 ) -> crate::persistence::PersistenceResult<Self> {
606 let session = Self::with_id_and_stores(
608 session_id,
609 client,
610 user_id,
611 tenant_id,
612 agent_id,
613 message_store,
614 session_store.clone(),
615 context_window_size,
616 );
617
618 let db_session =
620 crate::persistence::ChatSession::new(user_id, agent_id).with_id(session_id);
621 session_store.create_session(&db_session).await?;
622
623 Ok(session)
624 }
625
626 fn generate_session_id() -> uuid::Uuid {
628 uuid::Uuid::now_v7()
629 }
630
631 pub fn session_id(&self) -> uuid::Uuid {
633 self.session_id
634 }
635
636 pub fn session_id_str(&self) -> String {
638 self.session_id.to_string()
639 }
640
641 pub fn created_at(&self) -> std::time::Instant {
643 self.created_at
644 }
645
646 pub async fn load(
663 session_id: uuid::Uuid,
664 client: LLMClient,
665 user_id: uuid::Uuid,
666 tenant_id: uuid::Uuid,
667 agent_id: uuid::Uuid,
668 message_store: Arc<dyn crate::persistence::MessageStore>,
669 session_store: Arc<dyn crate::persistence::SessionStore>,
670 context_window_size: Option<usize>,
671 ) -> crate::persistence::PersistenceResult<Self> {
672 let _db_session = session_store
674 .get_session(session_id)
675 .await?
676 .ok_or_else(|| {
677 crate::persistence::PersistenceError::NotFound("Session not found".to_string())
678 })?;
679
680 let db_messages = if context_window_size.is_some() {
683 let total_count = message_store.count_session_messages(session_id).await?;
685
686 let rounds = context_window_size.unwrap_or(0);
688 let limit = (rounds * 2 + 20) as i64; let offset = std::cmp::max(0, total_count - limit);
692
693 message_store
694 .get_session_messages_paginated(session_id, offset, limit)
695 .await?
696 } else {
697 message_store.get_session_messages(session_id).await?
699 };
700
701 let mut messages = Vec::new();
703 for db_msg in db_messages {
704 let domain_role = match db_msg.role {
706 crate::persistence::MessageRole::System => crate::llm::types::Role::System,
707 crate::persistence::MessageRole::User => crate::llm::types::Role::User,
708 crate::persistence::MessageRole::Assistant => crate::llm::types::Role::Assistant,
709 crate::persistence::MessageRole::Tool => crate::llm::types::Role::Tool,
710 };
711
712 let domain_content = db_msg
714 .content
715 .text
716 .map(crate::llm::types::MessageContent::Text);
717
718 let domain_msg = ChatMessage {
720 role: domain_role,
721 content: domain_content,
722 name: None,
723 tool_calls: None,
724 tool_call_id: None,
725 };
726 messages.push(domain_msg);
727 }
728
729 let messages = Self::apply_sliding_window_static(&messages, context_window_size);
731
732 Ok(Self {
734 session_id,
735 user_id,
736 tenant_id,
737 agent_id,
738 client,
739 messages,
740 system_prompt: None, tools: Vec::new(), tool_executor: None, created_at: std::time::Instant::now(), metadata: std::collections::HashMap::new(), message_store,
746 session_store,
747 context_window_size,
748 last_response_metadata: None,
749 })
750 }
751
752 pub fn elapsed(&self) -> std::time::Duration {
754 self.created_at.elapsed()
755 }
756
757 pub fn set_metadata(&mut self, key: impl Into<String>, value: impl Into<String>) {
759 self.metadata.insert(key.into(), value.into());
760 }
761
762 pub fn get_metadata(&self, key: &str) -> Option<&String> {
764 self.metadata.get(key)
765 }
766
767 pub fn metadata(&self) -> &std::collections::HashMap<String, String> {
769 &self.metadata
770 }
771
772 pub fn with_system(mut self, prompt: impl Into<String>) -> Self {
774 self.system_prompt = Some(prompt.into());
775 self
776 }
777
778 pub fn with_context_window_size(mut self, size: Option<usize>) -> Self {
795 self.context_window_size = size;
796 self
797 }
798
799 pub fn with_tools(mut self, tools: Vec<Tool>, executor: Arc<dyn ToolExecutor>) -> Self {
801 self.tools = tools;
802 self.tool_executor = Some(executor);
803 self
804 }
805
806 pub fn with_tool_executor(mut self, executor: Arc<dyn ToolExecutor>) -> Self {
808 self.tool_executor = Some(executor);
809 self
810 }
811
812 pub async fn send(&mut self, content: impl Into<String>) -> LLMResult<String> {
814 self.messages.push(ChatMessage::user(content));
816
817 let mut builder = self.client.chat();
819
820 if let Some(ref system) = self.system_prompt {
822 builder = builder.system(system.clone());
823 }
824
825 let messages_for_context = self.apply_sliding_window();
827 builder = builder.messages(messages_for_context);
828
829 if let Some(ref executor) = self.tool_executor {
831 let tools = if self.tools.is_empty() {
832 executor.available_tools().await?
833 } else {
834 self.tools.clone()
835 };
836
837 if !tools.is_empty() {
838 builder = builder.tools(tools);
839 }
840
841 builder = builder.with_tool_executor(executor.clone());
842 }
843
844 let response = if self.tool_executor.is_some() {
846 builder.send_with_tools().await?
847 } else {
848 builder.send().await?
849 };
850
851 self.last_response_metadata = Some(super::types::LLMResponseMetadata::from(&response));
853
854 let content = response
856 .content()
857 .ok_or_else(|| LLMError::Other("No content in response".to_string()))?
858 .to_string();
859
860 self.messages.push(ChatMessage::assistant(&content));
862
863 if self.context_window_size.is_some() {
865 self.messages =
866 Self::apply_sliding_window_static(&self.messages, self.context_window_size);
867 }
868
869 Ok(content)
870 }
871
872 pub async fn send_with_content(&mut self, content: MessageContent) -> LLMResult<String> {
874 self.messages.push(ChatMessage::user_with_content(content));
875
876 let mut builder = self.client.chat();
878
879 if let Some(ref system) = self.system_prompt {
881 builder = builder.system(system.clone());
882 }
883
884 let messages_for_context = self.apply_sliding_window();
886 builder = builder.messages(messages_for_context);
887
888 if let Some(ref executor) = self.tool_executor {
890 let tools = if self.tools.is_empty() {
891 executor.available_tools().await?
892 } else {
893 self.tools.clone()
894 };
895
896 if !tools.is_empty() {
897 builder = builder.tools(tools);
898 }
899
900 builder = builder.with_tool_executor(executor.clone());
901 }
902
903 let response = if self.tool_executor.is_some() {
905 builder.send_with_tools().await?
906 } else {
907 builder.send().await?
908 };
909
910 self.last_response_metadata = Some(super::types::LLMResponseMetadata::from(&response));
912
913 let content = response
915 .content()
916 .ok_or_else(|| LLMError::Other("No content in response".to_string()))?
917 .to_string();
918
919 self.messages.push(ChatMessage::assistant(&content));
921
922 if self.context_window_size.is_some() {
924 self.messages =
925 Self::apply_sliding_window_static(&self.messages, self.context_window_size);
926 }
927
928 Ok(content)
929 }
930
931 pub fn messages(&self) -> &[ChatMessage] {
933 &self.messages
934 }
935
936 pub fn messages_mut(&mut self) -> &mut Vec<ChatMessage> {
938 &mut self.messages
939 }
940
941 pub fn clear(&mut self) {
943 self.messages.clear();
944 }
945
946 pub fn len(&self) -> usize {
948 self.messages.len()
949 }
950
951 pub fn is_empty(&self) -> bool {
953 self.messages.is_empty()
954 }
955
956 pub fn set_context_window_size(&mut self, size: Option<usize>) {
973 self.context_window_size = size;
974 }
975
976 pub fn context_window_size(&self) -> Option<usize> {
978 self.context_window_size
979 }
980
981 pub fn last_response_metadata(&self) -> Option<&super::types::LLMResponseMetadata> {
983 self.last_response_metadata.as_ref()
984 }
985
986 fn apply_sliding_window(&self) -> Vec<ChatMessage> {
1009 Self::apply_sliding_window_static(&self.messages, self.context_window_size)
1010 }
1011
1012 pub fn apply_sliding_window_static(
1021 messages: &[ChatMessage],
1022 window_size: Option<usize>,
1023 ) -> Vec<ChatMessage> {
1024 let max_rounds = match window_size {
1025 Some(size) if size > 0 => size,
1026 _ => return messages.to_vec(), };
1028
1029 let mut system_messages = Vec::new();
1031 let mut conversation_messages = Vec::new();
1032
1033 for msg in messages {
1034 if msg.role == Role::System {
1035 system_messages.push(msg.clone());
1036 } else {
1037 conversation_messages.push(msg.clone());
1038 }
1039 }
1040
1041 let max_messages = max_rounds * 2;
1043
1044 if conversation_messages.len() <= max_messages {
1045 return messages.to_vec();
1047 }
1048
1049 let start_index = conversation_messages.len() - max_messages;
1051 let limited_conversation: Vec<ChatMessage> = conversation_messages
1052 .into_iter()
1053 .skip(start_index)
1054 .collect();
1055
1056 let mut result = system_messages;
1058 result.extend(limited_conversation);
1059
1060 result
1061 }
1062
1063 pub async fn save(&self) -> crate::persistence::PersistenceResult<()> {
1065 let db_session = crate::persistence::ChatSession::new(self.user_id, self.agent_id)
1067 .with_id(self.session_id)
1068 .with_metadata("client_version", serde_json::json!("0.1.0"));
1069
1070 self.session_store.create_session(&db_session).await?;
1072
1073 for msg in self.messages.iter() {
1075 let persistence_role = match msg.role {
1077 crate::llm::types::Role::System => crate::persistence::MessageRole::System,
1078 crate::llm::types::Role::User => crate::persistence::MessageRole::User,
1079 crate::llm::types::Role::Assistant => crate::persistence::MessageRole::Assistant,
1080 crate::llm::types::Role::Tool => crate::persistence::MessageRole::Tool,
1081 };
1082
1083 let persistence_content = match &msg.content {
1085 Some(crate::llm::types::MessageContent::Text(text)) => {
1086 crate::persistence::MessageContent::text(text)
1087 }
1088 Some(crate::llm::types::MessageContent::Parts(parts)) => {
1089 let text = parts
1091 .iter()
1092 .filter_map(|part| {
1093 if let crate::llm::types::ContentPart::Text { text } = part {
1094 Some(text.clone())
1095 } else {
1096 None
1097 }
1098 })
1099 .collect::<Vec<_>>()
1100 .join("\n");
1101 crate::persistence::MessageContent::text(text)
1102 }
1103 None => crate::persistence::MessageContent::text(""),
1104 };
1105
1106 let llm_message = crate::persistence::LLMMessage::new(
1107 self.session_id,
1108 self.agent_id,
1109 self.user_id,
1110 self.tenant_id,
1111 persistence_role,
1112 persistence_content,
1113 );
1114
1115 self.message_store.save_message(&llm_message).await?;
1117 }
1118
1119 Ok(())
1120 }
1121
1122 pub async fn delete(&self) -> crate::persistence::PersistenceResult<()> {
1124 self.message_store
1126 .delete_session_messages(self.session_id)
1127 .await?;
1128
1129 self.session_store.delete_session(self.session_id).await?;
1131
1132 Ok(())
1133 }
1134}
1135
1136pub fn function_tool(
1164 name: impl Into<String>,
1165 description: impl Into<String>,
1166 parameters: serde_json::Value,
1167) -> Tool {
1168 Tool::function(name, description, parameters)
1169}