1use super::client::{ChatSession, LLMClient};
26use super::provider::{ChatStream, LLMProvider};
27use super::tool_executor::ToolExecutor;
28use super::types::{ChatMessage, LLMError, LLMResult, Tool};
29use crate::llm::{
30 AnthropicConfig, AnthropicProvider, GeminiConfig, GeminiProvider, OllamaConfig, OllamaProvider,
31};
32use crate::prompt;
33use futures::{Stream, StreamExt};
34use mofa_kernel::agent::AgentMetadata;
35use mofa_kernel::agent::AgentState;
36use mofa_kernel::plugin::{AgentPlugin, PluginType};
37use mofa_plugins::tts::TTSPlugin;
38use std::collections::HashMap;
39use std::io::Write;
40use std::pin::Pin;
41use std::sync::Arc;
42use std::sync::atomic::{AtomicBool, Ordering};
43use std::time::Duration;
44use tokio::sync::{Mutex, RwLock};
45
/// Streamed TTS output: chunks of f32 PCM samples, each paired with a
/// `Duration` (presumably the time that chunk took to synthesize, judging by
/// the `_took` binding at the consumption sites — confirm with Kokoro docs).
pub type TtsAudioStream = Pin<Box<dyn Stream<Item = (Vec<f32>, Duration)> + Send>>;
48
/// Cooperative cancellation flag shared between the chat loop and the spawned
/// TTS audio task. All clones observe the same underlying flag.
struct CancellationToken {
    cancel: Arc<AtomicBool>,
}

impl CancellationToken {
    /// Creates a fresh, un-cancelled token.
    fn new() -> Self {
        let cancel = Arc::new(AtomicBool::new(false));
        Self { cancel }
    }

    /// True once `cancel()` has been called on any handle to this token.
    fn is_cancelled(&self) -> bool {
        self.cancel.load(Ordering::Relaxed)
    }

    /// Signals cancellation to every handle sharing this token.
    fn cancel(&self) {
        self.cancel.store(true, Ordering::Relaxed);
    }

    /// Returns a new handle backed by the same flag.
    fn clone_token(&self) -> CancellationToken {
        let cancel = Arc::clone(&self.cancel);
        CancellationToken { cancel }
    }
}
75
/// Streamed assistant text: each item is a text chunk or an error. A
/// sentinel error containing "__stream_end__" marks normal end-of-stream
/// (see the match arms in `chat_with_tts_internal`).
pub type TextStream = Pin<Box<dyn Stream<Item = LLMResult<String>> + Send>>;
80
/// Pairs the synthesis input sink with the spawned task that drains the audio
/// output stream; the join handle is kept so callers can await completion.
#[cfg(feature = "kokoro")]
struct TTSStreamHandle {
    // Accepts sentences (String) queued for synthesis.
    sink: mofa_plugins::tts::kokoro_wrapper::SynthSink<String>,
    // Task forwarding synthesized audio chunks to the user callback.
    _stream_handle: tokio::task::JoinHandle<()>,
}
89
90struct TTSSession {
92 cancellation_token: CancellationToken,
93 is_active: Arc<AtomicBool>,
94}
95
96impl TTSSession {
97 fn new(token: CancellationToken) -> Self {
98 let is_active = Arc::new(AtomicBool::new(true));
99 TTSSession {
100 cancellation_token: token,
101 is_active,
102 }
103 }
104
105 fn cancel(&self) {
106 self.cancellation_token.cancel();
107 self.is_active.store(false, Ordering::Relaxed);
108 }
109
110 fn is_active(&self) -> bool {
111 self.is_active.load(Ordering::Relaxed)
112 }
113}
114
/// Accumulates streamed text chunks and cuts them into sentences at sentence-
/// ending punctuation, so TTS can speak complete sentences.
///
/// NOTE(review): ASCII '.' is deliberately absent from the terminator set
/// (presumably to avoid splitting on decimals/abbreviations) — confirm.
struct SentenceBuffer {
    buffer: String,
}

impl SentenceBuffer {
    /// Creates an empty buffer.
    fn new() -> Self {
        Self {
            buffer: String::new(),
        }
    }

    /// Appends `text` and, if the buffer now contains at least one complete
    /// sentence, returns everything up to (and including) the LAST sentence
    /// terminator, trimmed; the incomplete tail stays buffered.
    ///
    /// Bug fix: the previous version returned on the FIRST terminator found
    /// while iterating the incoming chunk, discarding the remaining characters
    /// of that chunk — any text after a terminator in the same chunk was lost.
    fn push(&mut self, text: &str) -> Option<String> {
        self.buffer.push_str(text);

        // Byte offset just past the last sentence terminator, if any.
        let split_at = self
            .buffer
            .char_indices()
            .filter(|&(_, ch)| matches!(ch, '。' | '!' | '?' | '!' | '?'))
            .map(|(idx, ch)| idx + ch.len_utf8())
            .last()?;

        let tail = self.buffer.split_off(split_at);
        let sentence = self.buffer.trim().to_string();
        self.buffer = tail;

        if sentence.is_empty() {
            None
        } else {
            Some(sentence)
        }
    }

    /// Drains any trailing text that never received a terminator (end of
    /// stream); returns `None` when only whitespace remains.
    fn flush(&mut self) -> Option<String> {
        if self.buffer.trim().is_empty() {
            self.buffer.clear();
            None
        } else {
            let remaining = self.buffer.trim().to_string();
            self.buffer.clear();
            Some(remaining)
        }
    }
}
154
/// Events emitted while streaming a chat response.
#[derive(Debug, Clone)]
pub enum StreamEvent {
    /// A chunk of assistant text.
    Text(String),
    /// A tool call started; `id` correlates the subsequent argument deltas.
    ToolCallStart { id: String, name: String },
    /// Incremental fragment of the tool-call arguments (presumably JSON —
    /// confirm against the provider contract).
    ToolCallDelta { id: String, arguments_delta: String },
    /// Stream finished; carries the final aggregated text, if any.
    Done(Option<String>),
}
167
/// Static configuration for an LLM agent.
#[derive(Clone)]
pub struct LLMAgentConfig {
    /// Stable identifier for the agent.
    pub agent_id: String,
    /// Human-readable display name.
    pub name: String,
    /// Optional system prompt applied to new sessions.
    pub system_prompt: Option<String>,
    /// Sampling temperature forwarded to the provider, when set.
    pub temperature: Option<f32>,
    /// Upper bound on generated tokens, when set.
    pub max_tokens: Option<u32>,
    /// Free-form provider/agent-specific settings.
    pub custom_config: HashMap<String, String>,
    /// Logical user owning this agent, if any.
    pub user_id: Option<String>,
    /// Tenant scope, if any.
    pub tenant_id: Option<String>,
    /// Bound on the session context window passed to `ChatSession`
    /// (unit — messages or tokens — not visible from here; confirm).
    pub context_window_size: Option<usize>,
}

impl Default for LLMAgentConfig {
    fn default() -> Self {
        LLMAgentConfig {
            agent_id: String::from("llm-agent"),
            name: String::from("LLM Agent"),
            system_prompt: None,
            temperature: Some(0.7),
            max_tokens: Some(4096),
            custom_config: HashMap::default(),
            user_id: None,
            tenant_id: None,
            context_window_size: None,
        }
    }
}
209
/// Multi-session LLM chat agent with optional tool calling, plugins,
/// persistence, and (feature-gated) Kokoro TTS output.
pub struct LLMAgent {
    // Static configuration (prompts, sampling params, ids).
    config: LLMAgentConfig,
    // Kernel-facing metadata (id, name, capabilities, lifecycle state).
    metadata: AgentMetadata,
    // Stateless client wrapper around `provider`.
    client: LLMClient,
    // All chat sessions, keyed by session id.
    sessions: Arc<RwLock<HashMap<String, Arc<RwLock<ChatSession>>>>>,
    // Id of the session used by the session-less convenience APIs
    // (`chat`, `history`, `clear_history`, ...).
    active_session_id: Arc<RwLock<String>>,
    // Explicitly registered tools; when empty, `ask` falls back to the
    // executor's `available_tools()`.
    tools: Vec<Tool>,
    tool_executor: Option<Arc<dyn ToolExecutor>>,
    // Hooks invoked around chat calls.
    event_handler: Option<Box<dyn LLMAgentEventHandler>>,
    plugins: Vec<Box<dyn AgentPlugin>>,
    state: AgentState,
    provider: Arc<dyn LLMProvider>,
    // Optional template plugin that can override the static system prompt.
    prompt_plugin: Option<Box<dyn prompt::PromptTemplatePlugin>>,
    tts_plugin: Option<Arc<Mutex<TTSPlugin>>>,
    #[cfg(feature = "kokoro")]
    // Cached downcast Kokoro engine so repeated `tts_create_stream` calls
    // skip the plugin lock + downcast.
    cached_kokoro_engine: Arc<Mutex<Option<Arc<mofa_plugins::tts::kokoro_wrapper::KokoroTTS>>>>,
    // Currently playing TTS session, if any; used by `interrupt_tts`.
    active_tts_session: Arc<Mutex<Option<TTSSession>>>,
    // Optional persistence backends; `None` keeps sessions in memory only.
    message_store: Option<Arc<dyn crate::persistence::MessageStore + Send + Sync>>,
    session_store: Option<Arc<dyn crate::persistence::SessionStore + Send + Sync>>,
    persistence_user_id: Option<uuid::Uuid>,
    persistence_agent_id: Option<uuid::Uuid>,
}
292
/// Lifecycle hooks around agent chat calls. Every hook has a pass-through
/// default, so implementors override only what they need.
#[async_trait::async_trait]
pub trait LLMAgentEventHandler: Send + Sync {
    /// Clones the handler behind the trait object (object-safe `Clone`).
    fn clone_box(&self) -> Box<dyn LLMAgentEventHandler>;

    /// Escape hatch for downcasting to a concrete handler type.
    fn as_any(&self) -> &dyn std::any::Any;

    /// Called before the user message is sent. Return `Some(msg)` to forward
    /// a (possibly rewritten) message; `None` makes the agent skip the
    /// provider call and return an empty response.
    async fn before_chat(&self, message: &str) -> LLMResult<Option<String>> {
        Ok(Some(message.to_string()))
    }

    /// Model-aware variant of `before_chat`; defaults to delegating to it.
    async fn before_chat_with_model(
        &self,
        message: &str,
        _model: &str,
    ) -> LLMResult<Option<String>> {
        self.before_chat(message).await
    }

    /// Called after a response arrives. Return `Some(resp)` to replace the
    /// response text; `None` keeps the original.
    async fn after_chat(&self, response: &str) -> LLMResult<Option<String>> {
        Ok(Some(response.to_string()))
    }

    /// Metadata-aware variant of `after_chat`; defaults to delegating to it.
    async fn after_chat_with_metadata(
        &self,
        response: &str,
        _metadata: &super::types::LLMResponseMetadata,
    ) -> LLMResult<Option<String>> {
        self.after_chat(response).await
    }

    /// Observes a tool invocation. Default implementation ignores the call;
    /// the meaning of a `Some(_)` return is decided by the call site.
    async fn on_tool_call(&self, name: &str, arguments: &str) -> LLMResult<Option<String>> {
        let _ = (name, arguments);
        Ok(None)
    }

    /// Called when a chat fails. Return `Some(fallback)` to swallow the error
    /// and use the fallback text as the reply; `None` propagates the error.
    async fn on_error(&self, error: &LLMError) -> LLMResult<Option<String>> {
        let _ = error;
        Ok(None)
    }
}
350
// Makes `Box<dyn LLMAgentEventHandler>` cloneable by delegating to the
// object-safe `clone_box` hook.
impl Clone for Box<dyn LLMAgentEventHandler> {
    fn clone(&self) -> Self {
        self.clone_box()
    }
}
356
357impl LLMAgent {
    /// Creates an in-memory agent with a freshly generated initial session.
    pub fn new(config: LLMAgentConfig, provider: Arc<dyn LLMProvider>) -> Self {
        Self::with_initial_session(config, provider, None)
    }
362
363 pub fn with_initial_session(
380 config: LLMAgentConfig,
381 provider: Arc<dyn LLMProvider>,
382 initial_session_id: Option<String>,
383 ) -> Self {
384 let client = LLMClient::new(provider.clone());
385
386 let mut session = if let Some(sid) = initial_session_id {
387 ChatSession::with_id_str(&sid, LLMClient::new(provider.clone()))
388 } else {
389 ChatSession::new(LLMClient::new(provider.clone()))
390 };
391
392 if let Some(ref prompt) = config.system_prompt {
394 session = session.with_system(prompt.clone());
395 }
396
397 session = session.with_context_window_size(config.context_window_size);
399
400 let session_id = session.session_id().to_string();
401 let session_arc = Arc::new(RwLock::new(session));
402
403 let mut sessions = HashMap::new();
405 sessions.insert(session_id.clone(), session_arc);
406
407 let agent_id = config.agent_id.clone();
409 let name = config.name.clone();
410
411 let capabilities = mofa_kernel::agent::AgentCapabilities::builder()
413 .tags(vec![
414 "llm".to_string(),
415 "chat".to_string(),
416 "text-generation".to_string(),
417 "multi-session".to_string(),
418 ])
419 .build();
420
421 Self {
422 config,
423 metadata: AgentMetadata {
424 id: agent_id,
425 name,
426 description: None,
427 version: None,
428 capabilities,
429 state: AgentState::Created,
430 },
431 client,
432 sessions: Arc::new(RwLock::new(sessions)),
433 active_session_id: Arc::new(RwLock::new(session_id)),
434 tools: Vec::new(),
435 tool_executor: None,
436 event_handler: None,
437 plugins: Vec::new(),
438 state: AgentState::Created,
439 provider,
440 prompt_plugin: None,
441 tts_plugin: None,
442 #[cfg(feature = "kokoro")]
443 cached_kokoro_engine: Arc::new(Mutex::new(None)),
444 active_tts_session: Arc::new(Mutex::new(None)),
445 message_store: None,
446 session_store: None,
447 persistence_user_id: None,
448 persistence_agent_id: None,
449 }
450 }
451
    /// Async constructor with optional persistence.
    ///
    /// Persistence is used only when ALL of `initial_session_id`, both stores,
    /// and all three persistence ids are present; otherwise this degrades to
    /// the in-memory behavior of `with_initial_session`. The persistence path
    /// tries, in order: load the session from the database; create it and
    /// persist; fall back to a memory-backed session if persisting fails.
    ///
    /// NOTE(review): `persistence_tenant_id` is consumed here but not stored
    /// on the struct — confirm that is intentional.
    pub async fn with_initial_session_async(
        config: LLMAgentConfig,
        provider: Arc<dyn LLMProvider>,
        initial_session_id: Option<String>,
        message_store: Option<Arc<dyn crate::persistence::MessageStore + Send + Sync>>,
        session_store: Option<Arc<dyn crate::persistence::SessionStore + Send + Sync>>,
        persistence_user_id: Option<uuid::Uuid>,
        persistence_tenant_id: Option<uuid::Uuid>,
        persistence_agent_id: Option<uuid::Uuid>,
    ) -> Self {
        let client = LLMClient::new(provider.clone());

        let initial_session_id_clone = initial_session_id.clone();

        // All-or-nothing gate for the persistence path.
        let session = if let (
            Some(sid),
            Some(msg_store),
            Some(sess_store),
            Some(user_id),
            Some(tenant_id),
            Some(agent_id),
        ) = (
            initial_session_id_clone,
            message_store.clone(),
            session_store.clone(),
            persistence_user_id,
            persistence_tenant_id,
            persistence_agent_id,
        ) {
            let msg_store_clone = msg_store.clone();
            let sess_store_clone = sess_store.clone();

            // Malformed ids are replaced rather than rejected.
            let session_uuid = uuid::Uuid::parse_str(&sid).unwrap_or_else(|_| {
                tracing::warn!("⚠️ 无效的 session_id 格式 '{}', 将生成新的 UUID", sid);
                uuid::Uuid::now_v7()
            });

            // Attempt 1: load an existing session (with history) from the DB.
            match ChatSession::load(
                session_uuid,
                LLMClient::new(provider.clone()),
                user_id,
                agent_id,
                tenant_id,
                msg_store,
                sess_store,
                config.context_window_size,
            )
            .await
            {
                Ok(loaded_session) => {
                    tracing::info!(
                        "✅ 从数据库加载会话: {} ({} 条消息)",
                        sid,
                        loaded_session.messages().len()
                    );
                    // NOTE(review): the configured system prompt is NOT applied
                    // on this path — presumably the stored history already
                    // contains it; confirm.
                    loaded_session
                }
                Err(e) => {
                    tracing::info!("📝 创建新会话并持久化: {} (数据库中不存在: {})", sid, e);

                    let msg_store_clone2 = msg_store_clone.clone();
                    let sess_store_clone2 = sess_store_clone.clone();

                    // Attempt 2: create a new session and persist it.
                    match ChatSession::with_id_and_stores_and_persist(
                        session_uuid,
                        LLMClient::new(provider.clone()),
                        user_id,
                        agent_id,
                        tenant_id,
                        msg_store_clone,
                        sess_store_clone,
                        config.context_window_size,
                    )
                    .await
                    {
                        Ok(mut new_session) => {
                            if let Some(ref prompt) = config.system_prompt {
                                new_session = new_session.with_system(prompt.clone());
                            }
                            new_session
                        }
                        Err(persist_err) => {
                            // Attempt 3: persisting failed — degrade to a
                            // store-aware but unpersisted session.
                            tracing::error!("❌ 持久化会话失败: {}, 降级为内存会话", persist_err);
                            let new_session = ChatSession::with_id_and_stores(
                                session_uuid,
                                LLMClient::new(provider.clone()),
                                user_id,
                                agent_id,
                                tenant_id,
                                msg_store_clone2,
                                sess_store_clone2,
                                config.context_window_size,
                            );
                            if let Some(ref prompt) = config.system_prompt {
                                new_session.with_system(prompt.clone())
                            } else {
                                new_session
                            }
                        }
                    }
                }
            }
        } else {
            // Pure in-memory path (mirrors `with_initial_session`).
            let mut session = if let Some(sid) = initial_session_id {
                ChatSession::with_id_str(&sid, LLMClient::new(provider.clone()))
            } else {
                ChatSession::new(LLMClient::new(provider.clone()))
            };
            if let Some(ref prompt) = config.system_prompt {
                session = session.with_system(prompt.clone());
            }
            session.with_context_window_size(config.context_window_size)
        };

        let session_id = session.session_id().to_string();
        let session_arc = Arc::new(RwLock::new(session));

        let mut sessions = HashMap::new();
        sessions.insert(session_id.clone(), session_arc);

        let agent_id = config.agent_id.clone();
        let name = config.name.clone();

        let capabilities = mofa_kernel::agent::AgentCapabilities::builder()
            .tags(vec![
                "llm".to_string(),
                "chat".to_string(),
                "text-generation".to_string(),
                "multi-session".to_string(),
            ])
            .build();

        Self {
            config,
            metadata: AgentMetadata {
                id: agent_id,
                name,
                description: None,
                version: None,
                capabilities,
                state: AgentState::Created,
            },
            client,
            sessions: Arc::new(RwLock::new(sessions)),
            active_session_id: Arc::new(RwLock::new(session_id)),
            tools: Vec::new(),
            tool_executor: None,
            event_handler: None,
            plugins: Vec::new(),
            state: AgentState::Created,
            provider,
            prompt_plugin: None,
            tts_plugin: None,
            #[cfg(feature = "kokoro")]
            cached_kokoro_engine: Arc::new(Mutex::new(None)),
            active_tts_session: Arc::new(Mutex::new(None)),
            message_store,
            session_store,
            persistence_user_id,
            persistence_agent_id,
        }
    }
653
    /// Returns the agent's configuration.
    pub fn config(&self) -> &LLMAgentConfig {
        &self.config
    }

    /// Returns the underlying LLM client.
    pub fn client(&self) -> &LLMClient {
        &self.client
    }

    /// Returns the id of the currently active session.
    pub async fn current_session_id(&self) -> String {
        self.active_session_id.read().await.clone()
    }
672
673 pub async fn create_session(&self) -> String {
684 let mut session = ChatSession::new(LLMClient::new(self.provider.clone()));
685
686 let mut system_prompt = self.config.system_prompt.clone();
688
689 if let Some(ref plugin) = self.prompt_plugin
690 && let Some(template) = plugin.get_current_template().await
691 {
692 system_prompt = match template.render(&[]) {
694 Ok(prompt) => Some(prompt),
695 Err(_) => self.config.system_prompt.clone(),
696 };
697 }
698
699 if let Some(ref prompt) = system_prompt {
700 session = session.with_system(prompt.clone());
701 }
702
703 session = session.with_context_window_size(self.config.context_window_size);
705
706 let session_id = session.session_id().to_string();
707 let session_arc = Arc::new(RwLock::new(session));
708
709 let mut sessions = self.sessions.write().await;
710 sessions.insert(session_id.clone(), session_arc);
711
712 session_id
713 }
714
715 pub async fn create_session_with_id(&self, session_id: impl Into<String>) -> LLMResult<String> {
725 let session_id = session_id.into();
726
727 {
728 let sessions = self.sessions.read().await;
729 if sessions.contains_key(&session_id) {
730 return Err(LLMError::Other(format!(
731 "Session with id '{}' already exists",
732 session_id
733 )));
734 }
735 }
736
737 let mut session =
738 ChatSession::with_id_str(&session_id, LLMClient::new(self.provider.clone()));
739
740 let mut system_prompt = self.config.system_prompt.clone();
742
743 if let Some(ref plugin) = self.prompt_plugin
744 && let Some(template) = plugin.get_current_template().await
745 {
746 system_prompt = match template.render(&[]) {
748 Ok(prompt) => Some(prompt),
749 Err(_) => self.config.system_prompt.clone(),
750 };
751 }
752
753 if let Some(ref prompt) = system_prompt {
754 session = session.with_system(prompt.clone());
755 }
756
757 session = session.with_context_window_size(self.config.context_window_size);
759
760 let session_arc = Arc::new(RwLock::new(session));
761
762 let mut sessions = self.sessions.write().await;
763 sessions.insert(session_id.clone(), session_arc);
764
765 Ok(session_id)
766 }
767
768 pub async fn switch_session(&self, session_id: &str) -> LLMResult<()> {
773 let sessions = self.sessions.read().await;
774 if !sessions.contains_key(session_id) {
775 return Err(LLMError::Other(format!(
776 "Session '{}' not found",
777 session_id
778 )));
779 }
780 drop(sessions);
781
782 let mut active = self.active_session_id.write().await;
783 *active = session_id.to_string();
784 Ok(())
785 }
786
787 pub async fn get_or_create_session(&self, session_id: impl Into<String>) -> String {
791 let session_id = session_id.into();
792
793 {
794 let sessions = self.sessions.read().await;
795 if sessions.contains_key(&session_id) {
796 return session_id;
797 }
798 }
799
800 let _ = self.create_session_with_id(&session_id).await;
802 session_id
803 }
804
805 pub async fn remove_session(&self, session_id: &str) -> LLMResult<()> {
810 let active = self.active_session_id.read().await.clone();
811 if active == session_id {
812 return Err(LLMError::Other(
813 "Cannot remove active session. Switch to another session first.".to_string(),
814 ));
815 }
816
817 let mut sessions = self.sessions.write().await;
818 if sessions.remove(session_id).is_none() {
819 return Err(LLMError::Other(format!(
820 "Session '{}' not found",
821 session_id
822 )));
823 }
824
825 Ok(())
826 }
827
828 pub async fn list_sessions(&self) -> Vec<String> {
830 let sessions = self.sessions.read().await;
831 sessions.keys().cloned().collect()
832 }
833
834 pub async fn session_count(&self) -> usize {
836 let sessions = self.sessions.read().await;
837 sessions.len()
838 }
839
840 pub async fn has_session(&self, session_id: &str) -> bool {
842 let sessions = self.sessions.read().await;
843 sessions.contains_key(session_id)
844 }
845
846 pub async fn tts_speak(&self, text: &str) -> LLMResult<()> {
858 let tts = self
859 .tts_plugin
860 .as_ref()
861 .ok_or_else(|| LLMError::Other("TTS plugin not configured".to_string()))?;
862
863 let mut tts_guard = tts.lock().await;
864 tts_guard
865 .synthesize_and_play(text)
866 .await
867 .map_err(|e| LLMError::Other(format!("TTS synthesis failed: {}", e)))
868 }
869
870 pub async fn tts_speak_streaming(
880 &self,
881 text: &str,
882 callback: Box<dyn Fn(Vec<u8>) + Send + Sync>,
883 ) -> LLMResult<()> {
884 let tts = self
885 .tts_plugin
886 .as_ref()
887 .ok_or_else(|| LLMError::Other("TTS plugin not configured".to_string()))?;
888
889 let mut tts_guard = tts.lock().await;
890 tts_guard
891 .synthesize_streaming(text, callback)
892 .await
893 .map_err(|e| LLMError::Other(format!("TTS streaming failed: {}", e)))
894 }
895
896 pub async fn tts_speak_f32_stream(
912 &self,
913 text: &str,
914 callback: Box<dyn Fn(Vec<f32>) + Send + Sync>,
915 ) -> LLMResult<()> {
916 let tts = self
917 .tts_plugin
918 .as_ref()
919 .ok_or_else(|| LLMError::Other("TTS plugin not configured".to_string()))?;
920
921 let mut tts_guard = tts.lock().await;
922 tts_guard
923 .synthesize_streaming_f32(text, callback)
924 .await
925 .map_err(|e| LLMError::Other(format!("TTS f32 streaming failed: {}", e)))
926 }
927
928 pub async fn tts_create_stream(&self, text: &str) -> LLMResult<TtsAudioStream> {
946 #[cfg(feature = "kokoro")]
947 {
948 use mofa_plugins::tts::kokoro_wrapper::KokoroTTS;
949
950 let cached_engine = {
952 let cache_guard = self.cached_kokoro_engine.lock().await;
953 cache_guard.clone()
954 };
955
956 let kokoro = if let Some(engine) = cached_engine {
957 engine
959 } else {
960 let tts = self
962 .tts_plugin
963 .as_ref()
964 .ok_or_else(|| LLMError::Other("TTS plugin not configured".to_string()))?;
965
966 let tts_guard = tts.lock().await;
967
968 let engine = tts_guard
969 .engine()
970 .ok_or_else(|| LLMError::Other("TTS engine not initialized".to_string()))?;
971
972 if let Some(kokoro_ref) = engine.as_any().downcast_ref::<KokoroTTS>() {
973 let cloned = kokoro_ref.clone();
975 let cloned_arc = Arc::new(cloned);
976
977 let voice = tts_guard
979 .stats()
980 .get("default_voice")
981 .and_then(|v| v.as_str())
982 .unwrap_or("default");
983
984 {
986 let mut cache_guard = self.cached_kokoro_engine.lock().await;
987 *cache_guard = Some(cloned_arc.clone());
988 }
989
990 cloned_arc
991 } else {
992 return Err(LLMError::Other("TTS engine is not KokoroTTS".to_string()));
993 }
994 };
995
996 let voice = "default"; let (mut sink, stream) = kokoro
999 .create_stream(voice)
1000 .await
1001 .map_err(|e| LLMError::Other(format!("Failed to create TTS stream: {}", e)))?;
1002
1003 sink.synth(text.to_string()).await.map_err(|e| {
1005 LLMError::Other(format!("Failed to submit text for synthesis: {}", e))
1006 })?;
1007
1008 return Ok(Box::pin(stream));
1010 }
1011
1012 #[cfg(not(feature = "kokoro"))]
1013 {
1014 Err(LLMError::Other("Kokoro feature not enabled".to_string()))
1015 }
1016 }
1017
    /// Batch-synthesizes `sentences`, delivering f32 sample chunks to
    /// `callback`.
    ///
    /// The audio-draining task is spawned detached: this method returns once
    /// all sentences are queued on the sink, while audio may still be flowing
    /// to `callback` in the background.
    pub async fn tts_speak_f32_stream_batch(
        &self,
        sentences: Vec<String>,
        callback: Box<dyn Fn(Vec<f32>) + Send + Sync>,
    ) -> LLMResult<()> {
        let tts = self
            .tts_plugin
            .as_ref()
            .ok_or_else(|| LLMError::Other("TTS plugin not configured".to_string()))?;

        let tts_guard = tts.lock().await;

        #[cfg(feature = "kokoro")]
        {
            use mofa_plugins::tts::kokoro_wrapper::KokoroTTS;

            let engine = tts_guard
                .engine()
                .ok_or_else(|| LLMError::Other("TTS engine not initialized".to_string()))?;

            // Only the Kokoro engine exposes the streaming sink API.
            if let Some(kokoro) = engine.as_any().downcast_ref::<KokoroTTS>() {
                // Voice selection comes from the plugin's advertised stats.
                let voice = tts_guard
                    .stats()
                    .get("default_voice")
                    .and_then(|v| v.as_str())
                    .unwrap_or("default")
                    .to_string();

                let (mut sink, mut stream) = kokoro
                    .create_stream(&voice)
                    .await
                    .map_err(|e| LLMError::Other(format!("Failed to create TTS stream: {}", e)))?;

                // Drain synthesized audio to the callback until stream end.
                tokio::spawn(async move {
                    while let Some((audio, _took)) = stream.next().await {
                        callback(audio);
                    }
                });

                // Queue every sentence; synthesis happens asynchronously.
                for sentence in sentences {
                    sink.synth(sentence)
                        .await
                        .map_err(|e| LLMError::Other(format!("Failed to submit text: {}", e)))?;
                }

                return Ok(());
            }

            return Err(LLMError::Other("TTS engine is not KokoroTTS".to_string()));
        }

        #[cfg(not(feature = "kokoro"))]
        {
            Err(LLMError::Other("Kokoro feature not enabled".to_string()))
        }
    }
1099
1100 pub fn has_tts(&self) -> bool {
1102 self.tts_plugin.is_some()
1103 }
1104
1105 pub async fn interrupt_tts(&self) -> LLMResult<()> {
1117 let mut session_guard = self.active_tts_session.lock().await;
1118 if let Some(session) = session_guard.take() {
1119 session.cancel();
1120 }
1121 Ok(())
1122 }
1123
1124 pub async fn chat_with_tts(
1140 &self,
1141 session_id: &str,
1142 message: impl Into<String>,
1143 ) -> LLMResult<()> {
1144 self.chat_with_tts_internal(session_id, message, None).await
1145 }
1146
1147 pub async fn chat_with_tts_callback(
1158 &self,
1159 session_id: &str,
1160 message: impl Into<String>,
1161 callback: impl Fn(Vec<f32>) + Send + Sync + 'static,
1162 ) -> LLMResult<()> {
1163 self.chat_with_tts_internal(session_id, message, Some(Box::new(callback)))
1164 .await
1165 }
1166
    /// Creates the synthesis sink plus a spawned task that forwards audio
    /// chunks to `callback`, stopping early when `cancellation_token` fires.
    ///
    /// # Errors
    /// Fails when no TTS plugin is configured, the engine is uninitialized,
    /// or the engine is not Kokoro.
    #[cfg(feature = "kokoro")]
    async fn create_tts_stream_handle(
        &self,
        callback: Box<dyn Fn(Vec<f32>) + Send + Sync>,
        cancellation_token: Option<CancellationToken>,
    ) -> LLMResult<TTSStreamHandle> {
        use mofa_plugins::tts::kokoro_wrapper::KokoroTTS;

        let tts = self
            .tts_plugin
            .as_ref()
            .ok_or_else(|| LLMError::Other("TTS plugin not configured".to_string()))?;

        let tts_guard = tts.lock().await;
        let engine = tts_guard
            .engine()
            .ok_or_else(|| LLMError::Other("TTS engine not initialized".to_string()))?;

        let kokoro = engine
            .as_any()
            .downcast_ref::<KokoroTTS>()
            .ok_or_else(|| LLMError::Other("TTS engine is not KokoroTTS".to_string()))?;

        // Voice selection comes from the plugin's advertised stats.
        let voice = tts_guard
            .stats()
            .get("default_voice")
            .and_then(|v| v.as_str())
            .unwrap_or("default")
            .to_string();

        let (sink, mut stream) = kokoro
            .create_stream(&voice)
            .await
            .map_err(|e| LLMError::Other(format!("Failed to create TTS stream: {}", e)))?;

        let token_clone = cancellation_token.as_ref().map(|t| t.clone_token());

        // Forward audio until the stream ends or cancellation fires. The
        // token is checked before delivering each chunk, so a chunk received
        // after cancellation is dropped rather than delivered.
        let stream_handle = tokio::spawn(async move {
            while let Some((audio, _took)) = stream.next().await {
                if let Some(ref token) = token_clone {
                    if token.is_cancelled() {
                        break;
                    }
                }
                callback(audio);
            }
        });

        Ok(TTSStreamHandle {
            sink,
            _stream_handle: stream_handle,
        })
    }
1232
    /// Shared implementation behind `chat_with_tts{,_callback}`.
    ///
    /// With the `kokoro` feature and a callback: streams the reply, splits it
    /// into sentences, and queues each for synthesis; audio chunks go to the
    /// callback. Without a callback, the reply is only streamed to stdout —
    /// no audio. Without the feature, sentences are echoed as `[TTS]` lines.
    async fn chat_with_tts_internal(
        &self,
        session_id: &str,
        message: impl Into<String>,
        callback: Option<Box<dyn Fn(Vec<f32>) + Send + Sync>>,
    ) -> LLMResult<()> {
        #[cfg(feature = "kokoro")]
        {
            use mofa_plugins::tts::kokoro_wrapper::KokoroTTS;

            // No callback: degrade to text-only streaming on stdout.
            let callback = match callback {
                Some(cb) => cb,
                None => {
                    let mut text_stream =
                        self.chat_stream_with_session(session_id, message).await?;
                    while let Some(result) = text_stream.next().await {
                        match result {
                            Ok(text_chunk) => {
                                print!("{}", text_chunk);
                                std::io::stdout().flush().map_err(|e| {
                                    LLMError::Other(format!("Failed to flush stdout: {}", e))
                                })?;
                            }
                            // In-band end-of-stream marker from the text stream.
                            Err(e) if e.to_string().contains("__stream_end__") => break,
                            Err(e) => return Err(e),
                        }
                    }
                    println!();
                    return Ok(());
                }
            };

            // Stop any previous playback before starting a new one.
            self.interrupt_tts().await?;

            let cancellation_token = CancellationToken::new();

            let mut tts_handle = self
                .create_tts_stream_handle(callback, Some(cancellation_token.clone_token()))
                .await?;

            let session = TTSSession::new(cancellation_token);

            // Publish the session so interrupt_tts() can cancel it.
            {
                let mut active_session = self.active_tts_session.lock().await;
                *active_session = Some(session);
            }

            let mut buffer = SentenceBuffer::new();

            let mut text_stream = self.chat_stream_with_session(session_id, message).await?;

            while let Some(result) = text_stream.next().await {
                match result {
                    Ok(text_chunk) => {
                        // Bail out early if another call interrupted us.
                        {
                            let active_session = self.active_tts_session.lock().await;
                            if let Some(ref session) = *active_session {
                                if !session.is_active() {
                                    return Ok(());
                                }
                            }
                        }

                        print!("{}", text_chunk);
                        std::io::stdout().flush().map_err(|e| {
                            LLMError::Other(format!("Failed to flush stdout: {}", e))
                        })?;

                        // Queue each complete sentence for synthesis; submit
                        // failures are logged but do not abort the chat.
                        if let Some(sentence) = buffer.push(&text_chunk) {
                            if let Err(e) = tts_handle.sink.synth(sentence).await {
                                eprintln!("[TTS Error] Failed to submit sentence: {}", e);
                            }
                        }
                    }
                    Err(e) if e.to_string().contains("__stream_end__") => break,
                    Err(e) => return Err(e),
                }
            }

            // Speak any trailing text that never got a sentence terminator.
            if let Some(remaining) = buffer.flush() {
                if let Err(e) = tts_handle.sink.synth(remaining).await {
                    eprintln!("[TTS Error] Failed to submit final sentence: {}", e);
                }
            }

            // Unpublish the session; playback is winding down.
            {
                let mut active_session = self.active_tts_session.lock().await;
                *active_session = None;
            }

            // Wait (bounded) for the audio task to drain. Timeout and task
            // failures are intentionally swallowed — audio is best-effort.
            let _ = tokio::time::timeout(
                tokio::time::Duration::from_secs(30),
                tts_handle._stream_handle,
            )
            .await
            .map_err(|_| LLMError::Other("TTS stream processing timeout".to_string()))
            .and_then(|r| r.map_err(|e| LLMError::Other(format!("TTS stream task failed: {}", e))));

            Ok(())
        }

        #[cfg(not(feature = "kokoro"))]
        {
            let mut text_stream = self.chat_stream_with_session(session_id, message).await?;
            let mut buffer = SentenceBuffer::new();
            let mut sentences = Vec::new();

            while let Some(result) = text_stream.next().await {
                match result {
                    Ok(text_chunk) => {
                        print!("{}", text_chunk);
                        std::io::stdout().flush().map_err(|e| {
                            LLMError::Other(format!("Failed to flush stdout: {}", e))
                        })?;

                        if let Some(sentence) = buffer.push(&text_chunk) {
                            sentences.push(sentence);
                        }
                    }
                    Err(e) if e.to_string().contains("__stream_end__") => break,
                    Err(e) => return Err(e),
                }
            }

            if let Some(remaining) = buffer.flush() {
                sentences.push(remaining);
            }

            // No synthesis backend without kokoro: just echo the sentences
            // that WOULD have been spoken; the callback goes unused.
            if !sentences.is_empty()
                && let Some(cb) = callback
            {
                for sentence in &sentences {
                    println!("\n[TTS] {}", sentence);
                }
                let _ = cb;
            }

            Ok(())
        }
    }
1398
1399 async fn get_session_arc(&self, session_id: &str) -> LLMResult<Arc<RwLock<ChatSession>>> {
1401 let sessions = self.sessions.read().await;
1402 sessions
1403 .get(session_id)
1404 .cloned()
1405 .ok_or_else(|| LLMError::Other(format!("Session '{}' not found", session_id)))
1406 }
1407
1408 pub async fn chat(&self, message: impl Into<String>) -> LLMResult<String> {
1414 let session_id = self.active_session_id.read().await.clone();
1415 self.chat_with_session(&session_id, message).await
1416 }
1417
    /// Sends `message` in the given session, running the event-handler hooks
    /// around the provider call, and returns the (possibly post-processed)
    /// reply.
    pub async fn chat_with_session(
        &self,
        session_id: &str,
        message: impl Into<String>,
    ) -> LLMResult<String> {
        let message = message.into();

        let model = self.provider.default_model();

        // Pre-hook: the handler may rewrite the message, or veto the call by
        // returning None — in that case we return an empty response.
        let processed_message = if let Some(ref handler) = self.event_handler {
            match handler.before_chat_with_model(&message, model).await? {
                Some(msg) => msg,
                None => return Ok(String::new()),
            }
        } else {
            message
        };

        let session = self.get_session_arc(session_id).await?;

        // NOTE(review): the session write lock is held across the provider
        // round-trip and the post-hooks below, so concurrent calls on the
        // same session serialize here.
        let mut session_guard = session.write().await;
        let response = match session_guard.send(&processed_message).await {
            Ok(resp) => resp,
            Err(e) => {
                // Error hook: a Some(fallback) swallows the error.
                if let Some(ref handler) = self.event_handler
                    && let Some(fallback) = handler.on_error(&e).await?
                {
                    return Ok(fallback);
                }
                return Err(e);
            }
        };

        // Post-hook: prefer the metadata-aware variant when metadata exists;
        // a None result keeps the original response text.
        let final_response = if let Some(ref handler) = self.event_handler {
            let metadata = session_guard.last_response_metadata();
            if let Some(meta) = metadata {
                match handler.after_chat_with_metadata(&response, meta).await? {
                    Some(resp) => resp,
                    None => response,
                }
            } else {
                match handler.after_chat(&response).await? {
                    Some(resp) => resp,
                    None => response,
                }
            }
        } else {
            response
        };

        Ok(final_response)
    }
1489
1490 pub async fn ask(&self, question: impl Into<String>) -> LLMResult<String> {
1492 let question = question.into();
1493
1494 let mut builder = self.client.chat();
1495
1496 let mut system_prompt = self.config.system_prompt.clone();
1498
1499 if let Some(ref plugin) = self.prompt_plugin
1500 && let Some(template) = plugin.get_current_template().await
1501 {
1502 match template.render(&[]) {
1504 Ok(prompt) => system_prompt = Some(prompt),
1505 Err(_) => {
1506 system_prompt = self.config.system_prompt.clone();
1508 }
1509 }
1510 }
1511
1512 if let Some(ref system) = system_prompt {
1514 builder = builder.system(system.clone());
1515 }
1516
1517 if let Some(temp) = self.config.temperature {
1518 builder = builder.temperature(temp);
1519 }
1520
1521 if let Some(tokens) = self.config.max_tokens {
1522 builder = builder.max_tokens(tokens);
1523 }
1524
1525 builder = builder.user(question);
1526
1527 if let Some(ref executor) = self.tool_executor {
1529 let tools = if self.tools.is_empty() {
1530 executor.available_tools().await?
1531 } else {
1532 self.tools.clone()
1533 };
1534
1535 if !tools.is_empty() {
1536 builder = builder.tools(tools);
1537 }
1538
1539 builder = builder.with_tool_executor(executor.clone());
1540 let response = builder.send_with_tools().await?;
1541 return response
1542 .content()
1543 .map(|s| s.to_string())
1544 .ok_or_else(|| LLMError::Other("No content in response".to_string()));
1545 }
1546
1547 let response = builder.send().await?;
1548 response
1549 .content()
1550 .map(|s| s.to_string())
1551 .ok_or_else(|| LLMError::Other("No content in response".to_string()))
1552 }
1553
1554 pub async fn set_prompt_scenario(&self, scenario: impl Into<String>) {
1556 let scenario = scenario.into();
1557
1558 if let Some(ref plugin) = self.prompt_plugin {
1559 plugin.set_active_scenario(&scenario).await;
1560 }
1561 }
1562
1563 pub async fn clear_history(&self) {
1565 let session_id = self.active_session_id.read().await.clone();
1566 let _ = self.clear_session_history(&session_id).await;
1567 }
1568
1569 pub async fn clear_session_history(&self, session_id: &str) -> LLMResult<()> {
1571 let session = self.get_session_arc(session_id).await?;
1572 let mut session_guard = session.write().await;
1573 session_guard.clear();
1574 Ok(())
1575 }
1576
1577 pub async fn history(&self) -> Vec<ChatMessage> {
1579 let session_id = self.active_session_id.read().await.clone();
1580 self.get_session_history(&session_id)
1581 .await
1582 .unwrap_or_default()
1583 }
1584
1585 pub async fn get_session_history(&self, session_id: &str) -> LLMResult<Vec<ChatMessage>> {
1587 let session = self.get_session_arc(session_id).await?;
1588 let session_guard = session.read().await;
1589 Ok(session_guard.messages().to_vec())
1590 }
1591
    /// Registers the tool list and the executor that runs tool calls. An
    /// empty `tools` list makes `ask` fall back to the executor's
    /// `available_tools()`.
    pub fn set_tools(&mut self, tools: Vec<Tool>, executor: Arc<dyn ToolExecutor>) {
        self.tools = tools;
        self.tool_executor = Some(executor);

    }

    /// Installs the event handler invoked around chat calls.
    pub fn set_event_handler(&mut self, handler: Box<dyn LLMAgentEventHandler>) {
        self.event_handler = Some(handler);
    }

    /// Adds one plugin to the agent's plugin list.
    pub fn add_plugin<P: AgentPlugin + 'static>(&mut self, plugin: P) {
        self.plugins.push(Box::new(plugin));
    }

    /// Adds several plugins at once.
    pub fn add_plugins(&mut self, plugins: Vec<Box<dyn AgentPlugin>>) {
        self.plugins.extend(plugins);
    }
1615
1616 pub async fn ask_stream(&self, question: impl Into<String>) -> LLMResult<TextStream> {
1638 let question = question.into();
1639
1640 let mut builder = self.client.chat();
1641
1642 if let Some(ref system) = self.config.system_prompt {
1643 builder = builder.system(system.clone());
1644 }
1645
1646 if let Some(temp) = self.config.temperature {
1647 builder = builder.temperature(temp);
1648 }
1649
1650 if let Some(tokens) = self.config.max_tokens {
1651 builder = builder.max_tokens(tokens);
1652 }
1653
1654 builder = builder.user(question);
1655
1656 let chunk_stream = builder.send_stream().await?;
1658
1659 Ok(Self::chunk_stream_to_text_stream(chunk_stream))
1661 }
1662
1663 pub async fn chat_stream(&self, message: impl Into<String>) -> LLMResult<TextStream> {
1686 let session_id = self.active_session_id.read().await.clone();
1687 self.chat_stream_with_session(&session_id, message).await
1688 }
1689
    /// Streams a chat reply for a specific session.
    ///
    /// Flow: run the `before_chat_with_model` hook (which may rewrite or veto
    /// the message), snapshot the session history, send the request, append
    /// the user message to the session, and wrap the provider stream so the
    /// assistant reply is appended to history once the stream completes.
    pub async fn chat_stream_with_session(
        &self,
        session_id: &str,
        message: impl Into<String>,
    ) -> LLMResult<TextStream> {
        let message = message.into();

        let model = self.provider.default_model();

        // The event handler may transform the message, or cancel the chat
        // entirely by returning `None` (we then yield an empty stream).
        let processed_message = if let Some(ref handler) = self.event_handler {
            match handler.before_chat_with_model(&message, model).await? {
                Some(msg) => msg,
                None => return Ok(Box::pin(futures::stream::empty())),
            }
        } else {
            message
        };

        let session = self.get_session_arc(session_id).await?;

        // Snapshot the history under a short-lived read lock.
        let history = {
            let session_guard = session.read().await;
            session_guard.messages().to_vec()
        };

        let mut builder = self.client.chat();

        if let Some(ref system) = self.config.system_prompt {
            builder = builder.system(system.clone());
        }

        if let Some(temp) = self.config.temperature {
            builder = builder.temperature(temp);
        }

        if let Some(tokens) = self.config.max_tokens {
            builder = builder.max_tokens(tokens);
        }

        builder = builder.messages(history);
        builder = builder.user(processed_message.clone());

        let chunk_stream = builder.send_stream().await?;

        // Record the user message only after the request was accepted, so a
        // failed send does not pollute the session history.
        {
            let mut session_guard = session.write().await;
            session_guard
                .messages_mut()
                .push(ChatMessage::user(&processed_message));
        }

        let event_handler = self.event_handler.clone().map(Arc::new);
        let wrapped_stream =
            Self::create_history_updating_stream(chunk_stream, session, event_handler);

        Ok(wrapped_stream)
    }
1761
1762 pub async fn ask_stream_raw(&self, question: impl Into<String>) -> LLMResult<ChatStream> {
1766 let question = question.into();
1767
1768 let mut builder = self.client.chat();
1769
1770 if let Some(ref system) = self.config.system_prompt {
1771 builder = builder.system(system.clone());
1772 }
1773
1774 if let Some(temp) = self.config.temperature {
1775 builder = builder.temperature(temp);
1776 }
1777
1778 if let Some(tokens) = self.config.max_tokens {
1779 builder = builder.max_tokens(tokens);
1780 }
1781
1782 builder = builder.user(question);
1783
1784 builder.send_stream().await
1785 }
1786
1787 pub async fn chat_stream_with_full(
1808 &self,
1809 message: impl Into<String>,
1810 ) -> LLMResult<(TextStream, tokio::sync::oneshot::Receiver<String>)> {
1811 let session_id = self.active_session_id.read().await.clone();
1812 self.chat_stream_with_full_session(&session_id, message)
1813 .await
1814 }
1815
    /// Like [`Self::chat_stream_with_session`], but additionally returns a
    /// oneshot receiver that resolves with the complete (possibly
    /// handler-rewritten) response once the stream has been fully consumed.
    pub async fn chat_stream_with_full_session(
        &self,
        session_id: &str,
        message: impl Into<String>,
    ) -> LLMResult<(TextStream, tokio::sync::oneshot::Receiver<String>)> {
        let message = message.into();

        let model = self.provider.default_model();

        // The hook may veto the chat; in that case resolve the receiver with
        // an empty string so callers awaiting it are not left hanging.
        let processed_message = if let Some(ref handler) = self.event_handler {
            match handler.before_chat_with_model(&message, model).await? {
                Some(msg) => msg,
                None => {
                    let (tx, rx) = tokio::sync::oneshot::channel();
                    let _ = tx.send(String::new());
                    return Ok((Box::pin(futures::stream::empty()), rx));
                }
            }
        } else {
            message
        };

        let session = self.get_session_arc(session_id).await?;

        // Snapshot the history under a short-lived read lock.
        let history = {
            let session_guard = session.read().await;
            session_guard.messages().to_vec()
        };

        let mut builder = self.client.chat();

        if let Some(ref system) = self.config.system_prompt {
            builder = builder.system(system.clone());
        }

        if let Some(temp) = self.config.temperature {
            builder = builder.temperature(temp);
        }

        if let Some(tokens) = self.config.max_tokens {
            builder = builder.max_tokens(tokens);
        }

        builder = builder.messages(history);
        builder = builder.user(processed_message.clone());

        let chunk_stream = builder.send_stream().await?;

        // Append the user message only after the request was accepted.
        {
            let mut session_guard = session.write().await;
            session_guard
                .messages_mut()
                .push(ChatMessage::user(&processed_message));
        }

        let (tx, rx) = tokio::sync::oneshot::channel();

        let event_handler = self.event_handler.clone().map(Arc::new);
        let wrapped_stream =
            Self::create_collecting_stream(chunk_stream, session, tx, event_handler);

        Ok((wrapped_stream, rx))
    }
1892
1893 fn chunk_stream_to_text_stream(chunk_stream: ChatStream) -> TextStream {
1899 use futures::StreamExt;
1900
1901 let text_stream = chunk_stream.filter_map(|result| async move {
1902 match result {
1903 Ok(chunk) => {
1904 if let Some(choice) = chunk.choices.first()
1906 && let Some(ref content) = choice.delta.content
1907 && !content.is_empty()
1908 {
1909 return Some(Ok(content.clone()));
1910 }
1911 None
1912 }
1913 Err(e) => Some(Err(e)),
1914 }
1915 });
1916
1917 Box::pin(text_stream)
1918 }
1919
    /// Wraps a chunk stream so that, when it finishes, the accumulated
    /// assistant text is appended to `session` (with the sliding window
    /// re-applied) and the `after_chat*` hooks are fired.
    ///
    /// Implementation note: a sentinel `Err("__stream_end__")` is chained
    /// after the real stream so the finalization runs exactly once, and is
    /// filtered out again before reaching the caller.
    fn create_history_updating_stream(
        chunk_stream: ChatStream,
        session: Arc<RwLock<ChatSession>>,
        event_handler: Option<Arc<Box<dyn LLMAgentEventHandler>>>,
    ) -> TextStream {
        use super::types::LLMResponseMetadata;

        // Shared accumulators: the response text and the final-chunk metadata.
        let collected = Arc::new(tokio::sync::Mutex::new(String::new()));
        let collected_clone = collected.clone();
        let event_handler_clone = event_handler.clone();
        let metadata_collected = Arc::new(tokio::sync::Mutex::new(None::<LLMResponseMetadata>));
        let metadata_collected_clone = metadata_collected.clone();

        let stream = chunk_stream.filter_map(move |result| {
            let collected = collected.clone();
            let event_handler = event_handler.clone();
            let metadata_collected = metadata_collected.clone();
            async move {
                match result {
                    Ok(chunk) => {
                        if let Some(choice) = chunk.choices.first() {
                            // A chunk with a finish reason carries metadata
                            // but no text: stash it and emit nothing.
                            if choice.finish_reason.is_some() {
                                let metadata = LLMResponseMetadata::from(&chunk);
                                *metadata_collected.lock().await = Some(metadata);
                                return None;
                            }
                            if let Some(ref content) = choice.delta.content
                                && !content.is_empty()
                            {
                                let mut collected = collected.lock().await;
                                collected.push_str(content);
                                return Some(Ok(content.clone()));
                            }
                        }
                        None
                    }
                    Err(e) => {
                        if let Some(handler) = event_handler {
                            let _ = handler.on_error(&e).await;
                        }
                        Some(Err(e))
                    }
                }
            }
        });

        let stream = stream
            .chain(futures::stream::once(async move {
                let full_response = collected_clone.lock().await.clone();
                let metadata = metadata_collected_clone.lock().await.clone();
                if !full_response.is_empty() {
                    let mut session = session.write().await;
                    session
                        .messages_mut()
                        .push(ChatMessage::assistant(&full_response));

                    // Re-apply the sliding window so the stored history never
                    // exceeds the configured context size.
                    let window_size = session.context_window_size();
                    if window_size.is_some() {
                        let current_messages = session.messages().to_vec();
                        *session.messages_mut() = ChatSession::apply_sliding_window_static(
                            &current_messages,
                            window_size,
                        );
                    }

                    if let Some(handler) = event_handler_clone {
                        if let Some(meta) = &metadata {
                            let _ = handler.after_chat_with_metadata(&full_response, meta).await;
                        } else {
                            let _ = handler.after_chat(&full_response).await;
                        }
                    }
                }
                // Sentinel marking the finalization item; filtered out below.
                Err(LLMError::Other("__stream_end__".to_string()))
            }))
            .filter_map(|result| async move {
                match result {
                    Ok(s) => Some(Ok(s)),
                    Err(e) if e.to_string() == "__stream_end__" => None,
                    Err(e) => Some(Err(e)),
                }
            });

        Box::pin(stream)
    }
2008
    /// Like [`Self::create_history_updating_stream`], but additionally
    /// collects the full response and delivers it through `tx` once the
    /// stream ends. The `after_chat*` hooks may rewrite the collected
    /// response before it is sent; the session history keeps the original
    /// assistant text.
    fn create_collecting_stream(
        chunk_stream: ChatStream,
        session: Arc<RwLock<ChatSession>>,
        tx: tokio::sync::oneshot::Sender<String>,
        event_handler: Option<Arc<Box<dyn LLMAgentEventHandler>>>,
    ) -> TextStream {
        use super::types::LLMResponseMetadata;
        use futures::StreamExt;

        // Shared accumulators: the response text and the final-chunk metadata.
        let collected = Arc::new(tokio::sync::Mutex::new(String::new()));
        let collected_clone = collected.clone();
        let event_handler_clone = event_handler.clone();
        let metadata_collected = Arc::new(tokio::sync::Mutex::new(None::<LLMResponseMetadata>));
        let metadata_collected_clone = metadata_collected.clone();

        let stream = chunk_stream.filter_map(move |result| {
            let collected = collected.clone();
            let event_handler = event_handler.clone();
            let metadata_collected = metadata_collected.clone();
            async move {
                match result {
                    Ok(chunk) => {
                        if let Some(choice) = chunk.choices.first() {
                            // Final chunk: record metadata, emit nothing.
                            if choice.finish_reason.is_some() {
                                let metadata = LLMResponseMetadata::from(&chunk);
                                *metadata_collected.lock().await = Some(metadata);
                                return None;
                            }
                            if let Some(ref content) = choice.delta.content
                                && !content.is_empty()
                            {
                                let mut collected = collected.lock().await;
                                collected.push_str(content);
                                return Some(Ok(content.clone()));
                            }
                        }
                        None
                    }
                    Err(e) => {
                        if let Some(handler) = event_handler {
                            let _ = handler.on_error(&e).await;
                        }
                        Some(Err(e))
                    }
                }
            }
        });

        let stream = stream
            .chain(futures::stream::once(async move {
                let full_response = collected_clone.lock().await.clone();
                let mut processed_response = full_response.clone();
                let metadata = metadata_collected_clone.lock().await.clone();

                if !full_response.is_empty() {
                    let mut session = session.write().await;
                    session
                        .messages_mut()
                        .push(ChatMessage::assistant(&processed_response));

                    // Keep stored history within the configured window.
                    let window_size = session.context_window_size();
                    if window_size.is_some() {
                        let current_messages = session.messages().to_vec();
                        *session.messages_mut() = ChatSession::apply_sliding_window_static(
                            &current_messages,
                            window_size,
                        );
                    }

                    // Hooks may replace the response that is handed to `tx`.
                    if let Some(handler) = event_handler_clone {
                        if let Some(meta) = &metadata {
                            if let Ok(Some(resp)) = handler
                                .after_chat_with_metadata(&processed_response, meta)
                                .await
                            {
                                processed_response = resp;
                            }
                        } else if let Ok(Some(resp)) = handler.after_chat(&processed_response).await
                        {
                            processed_response = resp;
                        }
                    }
                }

                // Deliver the (possibly rewritten) full response; the
                // receiver may already be dropped, which is fine.
                let _ = tx.send(processed_response);

                // Sentinel marking the finalization item; filtered out below.
                Err(LLMError::Other("__stream_end__".to_string()))
            }))
            .filter_map(|result| async move {
                match result {
                    Ok(s) => Some(Ok(s)),
                    Err(e) if e.to_string() == "__stream_end__" => None,
                    Err(e) => Some(Err(e)),
                }
            });

        Box::pin(stream)
    }
2112}
2113
/// Fluent builder for `LLMAgent`.
pub struct LLMAgentBuilder {
    // Agent identity; defaults to a freshly generated UUIDv7 string.
    agent_id: String,
    name: Option<String>,
    // LLM backend; required before `build()` / `try_build()`.
    provider: Option<Arc<dyn LLMProvider>>,
    system_prompt: Option<String>,
    temperature: Option<f32>,
    max_tokens: Option<u32>,
    // Tool definitions plus the executor that runs tool calls.
    tools: Vec<Tool>,
    tool_executor: Option<Arc<dyn ToolExecutor>>,
    event_handler: Option<Box<dyn LLMAgentEventHandler>>,
    plugins: Vec<Box<dyn AgentPlugin>>,
    custom_config: HashMap<String, String>,
    prompt_plugin: Option<Box<dyn prompt::PromptTemplatePlugin>>,
    // Session / tenancy context propagated into the agent config.
    session_id: Option<String>,
    user_id: Option<String>,
    tenant_id: Option<String>,
    context_window_size: Option<usize>,
    // Persistence wiring (consumed by `build_async`).
    message_store: Option<Arc<dyn crate::persistence::MessageStore + Send + Sync>>,
    session_store: Option<Arc<dyn crate::persistence::SessionStore + Send + Sync>>,
    persistence_user_id: Option<uuid::Uuid>,
    persistence_tenant_id: Option<uuid::Uuid>,
    persistence_agent_id: Option<uuid::Uuid>,
}
2139
2140impl LLMAgentBuilder {
    /// Creates an empty builder with a freshly generated UUIDv7 agent id.
    pub fn new() -> Self {
        Self {
            agent_id: uuid::Uuid::now_v7().to_string(),
            name: None,
            provider: None,
            system_prompt: None,
            temperature: None,
            max_tokens: None,
            tools: Vec::new(),
            tool_executor: None,
            event_handler: None,
            plugins: Vec::new(),
            custom_config: HashMap::new(),
            prompt_plugin: None,
            session_id: None,
            user_id: None,
            tenant_id: None,
            context_window_size: None,
            message_store: None,
            session_store: None,
            persistence_user_id: None,
            persistence_tenant_id: None,
            persistence_agent_id: None,
        }
    }
2167
2168 pub fn with_id(mut self, id: impl Into<String>) -> Self {
2170 self.agent_id = id.into();
2171 self
2172 }
2173
2174 pub fn with_name(mut self, name: impl Into<String>) -> Self {
2176 self.name = Some(name.into());
2177 self
2178 }
2179
2180 pub fn with_provider(mut self, provider: Arc<dyn LLMProvider>) -> Self {
2182 self.provider = Some(provider);
2183 self
2184 }
2185
2186 pub fn with_system_prompt(mut self, prompt: impl Into<String>) -> Self {
2188 self.system_prompt = Some(prompt.into());
2189 self
2190 }
2191
2192 pub fn with_temperature(mut self, temperature: f32) -> Self {
2194 self.temperature = Some(temperature);
2195 self
2196 }
2197
2198 pub fn with_max_tokens(mut self, max_tokens: u32) -> Self {
2200 self.max_tokens = Some(max_tokens);
2201 self
2202 }
2203
2204 pub fn with_tool(mut self, tool: Tool) -> Self {
2206 self.tools.push(tool);
2207 self
2208 }
2209
2210 pub fn with_tools(mut self, tools: Vec<Tool>) -> Self {
2212 self.tools = tools;
2213 self
2214 }
2215
2216 pub fn with_tool_executor(mut self, executor: Arc<dyn ToolExecutor>) -> Self {
2218 self.tool_executor = Some(executor);
2219 self
2220 }
2221
2222 pub fn with_event_handler(mut self, handler: Box<dyn LLMAgentEventHandler>) -> Self {
2224 self.event_handler = Some(handler);
2225 self
2226 }
2227
2228 pub fn with_plugin(mut self, plugin: impl AgentPlugin + 'static) -> Self {
2230 self.plugins.push(Box::new(plugin));
2231 self
2232 }
2233
2234 pub fn with_plugins(mut self, plugins: Vec<Box<dyn AgentPlugin>>) -> Self {
2236 self.plugins.extend(plugins);
2237 self
2238 }
2239
    /// Wires a persistence plugin in three roles at once: store handles for
    /// session/message persistence, an `AgentPlugin` (from a clone), and the
    /// agent's event handler (consumes the plugin).
    pub fn with_persistence_plugin(
        mut self,
        plugin: crate::persistence::PersistencePlugin,
    ) -> Self {
        // Extract store handles and ids before the plugin value is moved.
        self.message_store = Some(plugin.message_store());
        self.session_store = plugin.session_store();
        self.persistence_user_id = Some(plugin.user_id());
        self.persistence_tenant_id = Some(plugin.tenant_id());
        self.persistence_agent_id = Some(plugin.agent_id());

        let plugin_box: Box<dyn AgentPlugin> = Box::new(plugin.clone());
        let event_handler: Box<dyn LLMAgentEventHandler> = Box::new(plugin);
        self.plugins.push(plugin_box);
        // NOTE: this overwrites any previously configured event handler.
        self.event_handler = Some(event_handler);
        self
    }
2295
2296 pub fn with_prompt_plugin(
2298 mut self,
2299 plugin: impl prompt::PromptTemplatePlugin + 'static,
2300 ) -> Self {
2301 self.prompt_plugin = Some(Box::new(plugin));
2302 self
2303 }
2304
2305 pub fn with_hot_reload_prompt_plugin(
2307 mut self,
2308 plugin: prompt::HotReloadableRhaiPromptPlugin,
2309 ) -> Self {
2310 self.prompt_plugin = Some(Box::new(plugin));
2311 self
2312 }
2313
2314 pub fn with_config(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
2316 self.custom_config.insert(key.into(), value.into());
2317 self
2318 }
2319
2320 pub fn with_session_id(mut self, session_id: impl Into<String>) -> Self {
2331 self.session_id = Some(session_id.into());
2332 self
2333 }
2334
2335 pub fn with_user(mut self, user_id: impl Into<String>) -> Self {
2348 self.user_id = Some(user_id.into());
2349 self
2350 }
2351
2352 pub fn with_tenant(mut self, tenant_id: impl Into<String>) -> Self {
2365 self.tenant_id = Some(tenant_id.into());
2366 self
2367 }
2368
2369 pub fn with_sliding_window(mut self, size: usize) -> Self {
2392 self.context_window_size = Some(size);
2393 self
2394 }
2395
2396 pub fn from_env() -> LLMResult<Self> {
2417 use super::openai::{OpenAIConfig, OpenAIProvider};
2418
2419 let api_key = std::env::var("OPENAI_API_KEY").map_err(|_| {
2420 LLMError::ConfigError("OPENAI_API_KEY environment variable not set".to_string())
2421 })?;
2422
2423 let mut config = OpenAIConfig::new(api_key);
2424
2425 if let Ok(base_url) = std::env::var("OPENAI_BASE_URL") {
2426 config = config.with_base_url(&base_url);
2427 }
2428
2429 if let Ok(model) = std::env::var("OPENAI_MODEL") {
2430 config = config.with_model(&model);
2431 }
2432
2433 Ok(Self::new()
2434 .with_provider(Arc::new(OpenAIProvider::with_config(config)))
2435 .with_temperature(0.7)
2436 .with_max_tokens(4096))
2437 }
2438
    /// Builds the agent synchronously.
    ///
    /// # Panics
    /// Panics if no provider was set; use `try_build` for a fallible variant.
    pub fn build(self) -> LLMAgent {
        let provider = self
            .provider
            .expect("LLM provider must be set before building");

        let config = LLMAgentConfig {
            agent_id: self.agent_id.clone(),
            name: self.name.unwrap_or_else(|| self.agent_id.clone()),
            system_prompt: self.system_prompt,
            temperature: self.temperature,
            max_tokens: self.max_tokens,
            custom_config: self.custom_config,
            user_id: self.user_id,
            tenant_id: self.tenant_id,
            context_window_size: self.context_window_size,
        };

        let mut agent = LLMAgent::with_initial_session(config, provider, self.session_id);

        agent.prompt_plugin = self.prompt_plugin;

        if let Some(executor) = self.tool_executor {
            agent.set_tools(self.tools, executor);
        }

        if let Some(handler) = self.event_handler {
            agent.set_event_handler(handler);
        }

        let mut plugins = self.plugins;
        let mut tts_plugin = None;

        // Pull any TTS plugin out of the generic plugin list; iterating in
        // reverse keeps remaining indices valid while removing elements.
        for i in (0..plugins.len()).rev() {
            if plugins[i].as_any().is::<mofa_plugins::tts::TTSPlugin>() {
                let plugin = plugins.remove(i);
                if let Ok(tts) = plugin.into_any().downcast::<mofa_plugins::tts::TTSPlugin>() {
                    tts_plugin = Some(Arc::new(Mutex::new(*tts)));
                }
            }
        }

        agent.add_plugins(plugins);

        agent.tts_plugin = tts_plugin;

        agent
    }
2498
2499 pub fn try_build(self) -> LLMResult<LLMAgent> {
2503 let provider = self
2504 .provider
2505 .ok_or_else(|| LLMError::ConfigError("LLM provider not set".to_string()))?;
2506
2507 let config = LLMAgentConfig {
2508 agent_id: self.agent_id.clone(),
2509 name: self.name.unwrap_or_else(|| self.agent_id.clone()),
2510 system_prompt: self.system_prompt,
2511 temperature: self.temperature,
2512 max_tokens: self.max_tokens,
2513 custom_config: self.custom_config,
2514 user_id: self.user_id,
2515 tenant_id: self.tenant_id,
2516 context_window_size: self.context_window_size,
2517 };
2518
2519 let mut agent = LLMAgent::with_initial_session(config, provider, self.session_id);
2520
2521 if let Some(executor) = self.tool_executor {
2522 agent.set_tools(self.tools, executor);
2523 }
2524
2525 if let Some(handler) = self.event_handler {
2526 agent.set_event_handler(handler);
2527 }
2528
2529 let mut plugins = self.plugins;
2531 let mut tts_plugin = None;
2532
2533 for i in (0..plugins.len()).rev() {
2535 if plugins[i].as_any().is::<mofa_plugins::tts::TTSPlugin>() {
2536 let plugin = plugins.remove(i);
2539 if let Ok(tts) = plugin.into_any().downcast::<mofa_plugins::tts::TTSPlugin>() {
2541 tts_plugin = Some(Arc::new(Mutex::new(*tts)));
2542 }
2543 }
2544 }
2545
2546 agent.add_plugins(plugins);
2548
2549 agent.tts_plugin = tts_plugin;
2551
2552 Ok(agent)
2553 }
2554
    /// Async variant of `build` that also wires persistence stores and can
    /// derive a persistence tenant id from the textual `tenant_id`.
    ///
    /// # Panics
    /// Panics if no provider was set.
    pub async fn build_async(mut self) -> LLMAgent {
        let provider = self
            .provider
            .expect("LLM provider must be set before building");

        // Keep a copy: `self.tenant_id` is moved into the config below.
        let tenant_id_for_persistence = self.tenant_id.clone();

        let config = LLMAgentConfig {
            agent_id: self.agent_id.clone(),
            name: self.name.unwrap_or_else(|| self.agent_id.clone()),
            system_prompt: self.system_prompt,
            temperature: self.temperature,
            max_tokens: self.max_tokens,
            custom_config: self.custom_config,
            user_id: self.user_id,
            tenant_id: self.tenant_id,
            context_window_size: self.context_window_size,
        };

        // Fall back to parsing the textual tenant id when a session store is
        // configured but no explicit persistence tenant id was provided.
        let persistence_tenant_id = if self.session_store.is_some()
            && self.persistence_tenant_id.is_none()
            && tenant_id_for_persistence.is_some()
        {
            uuid::Uuid::parse_str(&tenant_id_for_persistence.unwrap()).ok()
        } else {
            self.persistence_tenant_id
        };

        let mut agent = LLMAgent::with_initial_session_async(
            config,
            provider,
            self.session_id,
            self.message_store,
            self.session_store,
            self.persistence_user_id,
            persistence_tenant_id,
            self.persistence_agent_id,
        )
        .await;

        agent.prompt_plugin = self.prompt_plugin;

        // If no tools were given explicitly, ask the executor for its
        // advertised tool list (a failed query leaves the list empty).
        if self.tools.is_empty() {
            if let Some(executor) = self.tool_executor.as_ref() {
                if let Ok(tools) = executor.available_tools().await {
                    self.tools = tools;
                }
            }
        }

        if let Some(executor) = self.tool_executor {
            agent.set_tools(self.tools, executor);
        }

        let mut plugins = self.plugins;
        let mut tts_plugin = None;
        // Currently always false: history loading is deferred until after
        // agent initialization (see the Storage-plugin scan below).
        let history_loaded_from_plugin = false;

        // Pull any TTS plugin out of the generic plugin list; iterating in
        // reverse keeps remaining indices valid while removing elements.
        for i in (0..plugins.len()).rev() {
            if plugins[i].as_any().is::<mofa_plugins::tts::TTSPlugin>() {
                let plugin = plugins.remove(i);
                if let Ok(tts) = plugin.into_any().downcast::<mofa_plugins::tts::TTSPlugin>() {
                    tts_plugin = Some(Arc::new(Mutex::new(*tts)));
                }
            }
        }

        if !history_loaded_from_plugin {
            for plugin in &plugins {
                if plugin.metadata().plugin_type == PluginType::Storage
                    && plugin
                        .metadata()
                        .capabilities
                        .contains(&"message_persistence".to_string())
                {
                    // Log (Chinese): persistence plugin detected; history will
                    // be loaded after the agent initializes.
                    tracing::info!("📦 检测到持久化插件,将在 agent 初始化后加载历史");
                    break;
                }
            }
        }

        agent.add_plugins(plugins);

        agent.tts_plugin = tts_plugin;

        if let Some(handler) = self.event_handler {
            agent.set_event_handler(handler);
        }

        agent
    }
2695}
2696
2697impl LLMAgentBuilder {
2702 pub fn from_config_file(path: impl AsRef<std::path::Path>) -> LLMResult<Self> {
2713 let config = crate::config::AgentYamlConfig::from_file(path)
2714 .map_err(|e| LLMError::ConfigError(e.to_string()))?;
2715 Self::from_yaml_config(config)
2716 }
2717
2718 pub fn from_yaml_config(config: crate::config::AgentYamlConfig) -> LLMResult<Self> {
2720 let mut builder = Self::new()
2721 .with_id(&config.agent.id)
2722 .with_name(&config.agent.name);
2723 if let Some(llm_config) = config.llm {
2725 let provider = create_provider_from_config(&llm_config)?;
2726 builder = builder.with_provider(Arc::new(provider));
2727
2728 if let Some(temp) = llm_config.temperature {
2729 builder = builder.with_temperature(temp);
2730 }
2731 if let Some(tokens) = llm_config.max_tokens {
2732 builder = builder.with_max_tokens(tokens);
2733 }
2734 if let Some(prompt) = llm_config.system_prompt {
2735 builder = builder.with_system_prompt(prompt);
2736 }
2737 }
2738
2739 Ok(builder)
2740 }
2741
    /// Loads an agent definition (agent + provider rows) from the database by
    /// its `agent_code` and converts it into a builder.
    ///
    /// # Errors
    /// Fails when the store query errors or no agent with that code exists.
    #[cfg(feature = "persistence-postgres")]
    pub async fn from_database<S>(store: &S, agent_code: &str) -> LLMResult<Self>
    where
        S: crate::persistence::AgentStore + Send + Sync,
    {
        let config = store
            .get_agent_by_code_with_provider(agent_code)
            .await
            .map_err(|e| LLMError::Other(format!("Failed to load agent from database: {}", e)))?
            .ok_or_else(|| {
                LLMError::Other(format!(
                    "Agent with code '{}' not found in database",
                    agent_code
                ))
            })?;

        Self::from_agent_config(&config)
    }
2785
    /// Tenant-scoped variant of `from_database`: the lookup is restricted to
    /// agents belonging to `tenant_id`.
    ///
    /// # Errors
    /// Fails when the store query errors or the agent does not exist for the
    /// given tenant.
    #[cfg(feature = "persistence-postgres")]
    pub async fn from_database_with_tenant<S>(
        store: &S,
        tenant_id: uuid::Uuid,
        agent_code: &str,
    ) -> LLMResult<Self>
    where
        S: crate::persistence::AgentStore + Send + Sync,
    {
        let config = store
            .get_agent_by_code_and_tenant_with_provider(tenant_id, agent_code)
            .await
            .map_err(|e| LLMError::Other(format!("Failed to load agent from database: {}", e)))?
            .ok_or_else(|| {
                LLMError::Other(format!(
                    "Agent with code '{}' not found for tenant {}",
                    agent_code, tenant_id
                ))
            })?;

        Self::from_agent_config(&config)
    }
2832
    /// Convenience alias for `from_database`.
    #[cfg(feature = "persistence-postgres")]
    pub async fn with_database_agent<S>(store: &S, agent_code: &str) -> LLMResult<Self>
    where
        S: crate::persistence::AgentStore + Send + Sync,
    {
        Self::from_database(store, agent_code).await
    }
2853
    /// Converts a database `AgentConfig` (agent + provider rows) into a
    /// builder, selecting a provider implementation from `provider_type`.
    ///
    /// # Errors
    /// Fails when the agent or provider row is disabled, or when the
    /// provider type is not one of: openai/azure/compatible/local,
    /// anthropic, gemini, ollama.
    #[cfg(feature = "persistence-postgres")]
    pub fn from_agent_config(config: &crate::persistence::AgentConfig) -> LLMResult<Self> {
        use super::openai::{OpenAIConfig, OpenAIProvider};

        let agent = &config.agent;
        let provider = &config.provider;

        // Refuse to build from disabled rows.
        if !agent.agent_status {
            return Err(LLMError::Other(format!(
                "Agent '{}' is disabled (agent_status = false)",
                agent.agent_code
            )));
        }

        if !provider.enabled {
            return Err(LLMError::Other(format!(
                "Provider '{}' is disabled (enabled = false)",
                provider.provider_name
            )));
        }

        // Map the stored provider type onto a concrete provider impl.
        let llm_provider: Arc<dyn super::LLMProvider> = match provider.provider_type.as_str() {
            "openai" | "azure" | "compatible" | "local" => {
                let mut openai_config = OpenAIConfig::new(provider.api_key.clone());
                openai_config = openai_config.with_base_url(&provider.api_base);
                openai_config = openai_config.with_model(&agent.model_name);

                if let Some(temp) = agent.temperature {
                    openai_config = openai_config.with_temperature(temp);
                }

                if let Some(max_tokens) = agent.max_completion_tokens {
                    openai_config = openai_config.with_max_tokens(max_tokens as u32);
                }

                Arc::new(OpenAIProvider::with_config(openai_config))
            }
            "anthropic" => {
                let mut cfg = AnthropicConfig::new(provider.api_key.clone());
                cfg = cfg.with_base_url(&provider.api_base);
                cfg = cfg.with_model(&agent.model_name);

                if let Some(temp) = agent.temperature {
                    cfg = cfg.with_temperature(temp);
                }
                if let Some(tokens) = agent.max_completion_tokens {
                    cfg = cfg.with_max_tokens(tokens as u32);
                }

                Arc::new(AnthropicProvider::with_config(cfg))
            }
            "gemini" => {
                let mut cfg = GeminiConfig::new(provider.api_key.clone());
                cfg = cfg.with_base_url(&provider.api_base);
                cfg = cfg.with_model(&agent.model_name);

                if let Some(temp) = agent.temperature {
                    cfg = cfg.with_temperature(temp);
                }
                if let Some(tokens) = agent.max_completion_tokens {
                    cfg = cfg.with_max_tokens(tokens as u32);
                }

                Arc::new(GeminiProvider::with_config(cfg))
            }
            "ollama" => {
                let mut ollama_config = OllamaConfig::new();
                ollama_config = ollama_config.with_base_url(&provider.api_base);
                ollama_config = ollama_config.with_model(&agent.model_name);

                if let Some(temp) = agent.temperature {
                    ollama_config = ollama_config.with_temperature(temp);
                }

                if let Some(max_tokens) = agent.max_completion_tokens {
                    ollama_config = ollama_config.with_max_tokens(max_tokens as u32);
                }

                Arc::new(OllamaProvider::with_config(ollama_config))
            }
            other => {
                return Err(LLMError::Other(format!(
                    "Unsupported provider type: {}",
                    other
                )));
            }
        };

        let mut builder = Self::new()
            .with_id(agent.id.clone())
            .with_name(agent.agent_name.clone())
            .with_provider(llm_provider)
            .with_system_prompt(agent.system_prompt.clone())
            .with_tenant(agent.tenant_id.to_string());

        if let Some(temp) = agent.temperature {
            builder = builder.with_temperature(temp);
        }
        if let Some(tokens) = agent.max_completion_tokens {
            builder = builder.with_max_tokens(tokens as u32);
        }
        if let Some(limit) = agent.context_limit {
            builder = builder.with_sliding_window(limit as usize);
        }

        // Copy free-form JSON params into the string config map; JSON
        // strings are unquoted, everything else is stringified as-is.
        if let Some(ref params) = agent.custom_params {
            if let Some(obj) = params.as_object() {
                for (key, value) in obj.iter() {
                    let value_str: String = match value {
                        serde_json::Value::String(s) => s.clone(),
                        serde_json::Value::Bool(b) => b.to_string(),
                        serde_json::Value::Number(n) => n.to_string(),
                        _ => value.to_string(),
                    };
                    builder = builder.with_config(key.as_str(), value_str);
                }
            }
        }

        if let Some(ref format) = agent.response_format {
            builder = builder.with_config("response_format", format);
        }

        if let Some(stream) = agent.stream {
            builder = builder.with_config("stream", if stream { "true" } else { "false" });
        }

        Ok(builder)
    }
2992}
2993
/// Builds an OpenAI-compatible provider from a YAML `llm` section.
///
/// Supported `provider` values: "openai" (key from config or the
/// `OPENAI_API_KEY` env var), "azure" (endpoint + key + deployment), and
/// "compatible"/"local" (base_url plus an optional model, defaulting to
/// "default").
///
/// # Errors
/// Returns `LLMError::ConfigError` when a required field is missing or the
/// provider name is unknown.
fn create_provider_from_config(
    config: &crate::config::LLMYamlConfig,
) -> LLMResult<super::openai::OpenAIProvider> {
    use super::openai::{OpenAIConfig, OpenAIProvider};

    match config.provider.as_str() {
        "openai" => {
            // Config value wins; the environment variable is a fallback.
            let api_key = config
                .api_key
                .clone()
                .or_else(|| std::env::var("OPENAI_API_KEY").ok())
                .ok_or_else(|| LLMError::ConfigError("OpenAI API key not set".to_string()))?;

            let mut openai_config = OpenAIConfig::new(api_key);

            if let Some(ref model) = config.model {
                openai_config = openai_config.with_model(model);
            }
            if let Some(ref base_url) = config.base_url {
                openai_config = openai_config.with_base_url(base_url);
            }
            if let Some(temp) = config.temperature {
                openai_config = openai_config.with_temperature(temp);
            }
            if let Some(tokens) = config.max_tokens {
                openai_config = openai_config.with_max_tokens(tokens);
            }

            Ok(OpenAIProvider::with_config(openai_config))
        }
        "azure" => {
            let endpoint = config.base_url.clone().ok_or_else(|| {
                LLMError::ConfigError("Azure endpoint (base_url) not set".to_string())
            })?;
            let api_key = config
                .api_key
                .clone()
                .or_else(|| std::env::var("AZURE_OPENAI_API_KEY").ok())
                .ok_or_else(|| LLMError::ConfigError("Azure API key not set".to_string()))?;
            // The deployment name falls back to the model name.
            let deployment = config
                .deployment
                .clone()
                .or_else(|| config.model.clone())
                .ok_or_else(|| {
                    LLMError::ConfigError("Azure deployment name not set".to_string())
                })?;

            Ok(OpenAIProvider::azure(endpoint, api_key, deployment))
        }
        "compatible" | "local" => {
            let base_url = config.base_url.clone().ok_or_else(|| {
                LLMError::ConfigError("base_url not set for compatible provider".to_string())
            })?;
            let model = config
                .model
                .clone()
                .unwrap_or_else(|| "default".to_string());

            Ok(OpenAIProvider::local(base_url, model))
        }
        other => Err(LLMError::ConfigError(format!(
            "Unknown provider: {}",
            other
        ))),
    }
}
3062
3063#[async_trait::async_trait]
3068impl mofa_kernel::agent::MoFAAgent for LLMAgent {
3069 fn id(&self) -> &str {
3070 &self.metadata.id
3071 }
3072
3073 fn name(&self) -> &str {
3074 &self.metadata.name
3075 }
3076
    /// Advertises the capability set of an LLM agent.
    fn capabilities(&self) -> &mofa_kernel::agent::AgentCapabilities {
        use mofa_kernel::agent::AgentCapabilities;

        // Capabilities are identical for every LLMAgent instance, so the
        // value is built once and cached process-wide.
        static CAPABILITIES: std::sync::OnceLock<AgentCapabilities> = std::sync::OnceLock::new();

        CAPABILITIES.get_or_init(|| {
            AgentCapabilities::builder()
                .tag("llm")
                .tag("chat")
                .tag("text-generation")
                .input_type(mofa_kernel::agent::InputType::Text)
                .output_type(mofa_kernel::agent::OutputType::Text)
                .supports_streaming(true)
                .supports_tools(true)
                .build()
        })
    }
3101
3102 async fn initialize(
3103 &mut self,
3104 ctx: &mofa_kernel::agent::AgentContext,
3105 ) -> mofa_kernel::agent::AgentResult<()> {
3106 let mut plugin_config = mofa_kernel::plugin::PluginConfig::new();
3108 for (k, v) in &self.config.custom_config {
3109 plugin_config.set(k, v);
3110 }
3111 if let Some(user_id) = &self.config.user_id {
3112 plugin_config.set("user_id", user_id);
3113 }
3114 if let Some(tenant_id) = &self.config.tenant_id {
3115 plugin_config.set("tenant_id", tenant_id);
3116 }
3117 let session_id = self.active_session_id.read().await.clone();
3118 plugin_config.set("session_id", session_id);
3119
3120 let plugin_ctx =
3121 mofa_kernel::plugin::PluginContext::new(self.id()).with_config(plugin_config);
3122
3123 for plugin in &mut self.plugins {
3124 plugin
3125 .load(&plugin_ctx)
3126 .await
3127 .map_err(|e| mofa_kernel::agent::AgentError::InitializationFailed(e.to_string()))?;
3128 plugin
3129 .init_plugin()
3130 .await
3131 .map_err(|e| mofa_kernel::agent::AgentError::InitializationFailed(e.to_string()))?;
3132 }
3133 self.state = mofa_kernel::agent::AgentState::Ready;
3134
3135 let _ = ctx;
3137
3138 Ok(())
3139 }
3140
3141 async fn execute(
3142 &mut self,
3143 input: mofa_kernel::agent::AgentInput,
3144 _ctx: &mofa_kernel::agent::AgentContext,
3145 ) -> mofa_kernel::agent::AgentResult<mofa_kernel::agent::AgentOutput> {
3146 use mofa_kernel::agent::{AgentError, AgentInput, AgentOutput};
3147
3148 let message = match input {
3150 AgentInput::Text(text) => text,
3151 AgentInput::Json(json) => json.to_string(),
3152 _ => {
3153 return Err(AgentError::ValidationFailed(
3154 "Unsupported input type for LLMAgent".to_string(),
3155 ));
3156 }
3157 };
3158
3159 let response = self
3161 .chat(&message)
3162 .await
3163 .map_err(|e| AgentError::ExecutionFailed(format!("LLM chat failed: {}", e)))?;
3164
3165 Ok(AgentOutput::text(response))
3167 }
3168
3169 async fn shutdown(&mut self) -> mofa_kernel::agent::AgentResult<()> {
3170 for plugin in &mut self.plugins {
3172 plugin
3173 .unload()
3174 .await
3175 .map_err(|e| mofa_kernel::agent::AgentError::ShutdownFailed(e.to_string()))?;
3176 }
3177 self.state = mofa_kernel::agent::AgentState::Shutdown;
3178 Ok(())
3179 }
3180
3181 fn state(&self) -> mofa_kernel::agent::AgentState {
3182 self.state.clone()
3183 }
3184}
3185
3186pub fn simple_llm_agent(
3205 agent_id: impl Into<String>,
3206 provider: Arc<dyn LLMProvider>,
3207 system_prompt: impl Into<String>,
3208) -> LLMAgent {
3209 LLMAgentBuilder::new()
3210 .with_id(agent_id)
3211 .with_provider(provider)
3212 .with_system_prompt(system_prompt)
3213 .build()
3214}
3215
3216pub fn agent_from_config(path: impl AsRef<std::path::Path>) -> LLMResult<LLMAgent> {
3227 LLMAgentBuilder::from_config_file(path)?.try_build()
3228}