1use std::collections::HashMap;
2
3use serde::de::DeserializeOwned;
4use serde_json::json;
5
6use error_stack::{Report, Result, ResultExt};
7use thiserror::Error;
8
9use tracing::info;
10
11use crate::chat::chat_base::{BaseChat, ChatError};
12use crate::chat::chat_tool::ChatTool;
13use crate::chat::message::Role;
14use crate::config::ModelCapability;
15use crate::prompt::assembler::assemble_output_description;
16use crate::schema::json_schema::JsonSchema;
17
18#[derive(Debug, Clone)]
19pub struct MultiChat {
20    pub base: BaseChat,
21
22    character_prompts: HashMap<String, String>,
23
24    pub current_character: String,
25
26    need_stream: bool,
27}
28
29impl MultiChat {
30    pub fn new_with_api_name(
31        api_name: &str,
32        character_prompts: HashMap<String, String>,
33        need_stream: bool,
34    ) -> Result<Self, ChatError> {
35        if character_prompts.is_empty() {
36            return Err(Report::new(ChatError::NoCharacterPrompts));
37        }
38
39        Ok(Self {
40            base: BaseChat::new_with_api_name(api_name, "", need_stream),
41            character_prompts,
42            current_character: String::new(),
43            need_stream,
44        })
45    }
46
47    pub fn new_with_model_capability(
48        model_capability: ModelCapability,
49        character_prompts: HashMap<String, String>,
50        need_stream: bool,
51    ) -> Result<Self, ChatError> {
52        if character_prompts.is_empty() {
53            return Err(Report::new(ChatError::NoCharacterPrompts));
54        }
55
56        Ok(Self {
57            base: BaseChat::new_with_model_capability(model_capability, "", need_stream),
58            character_prompts,
59            current_character: String::new(),
60            need_stream,
61        })
62    }
63
64    pub fn set_character(&mut self, character: &str) -> Result<(), ChatError> {
65        if !self.character_prompts.contains_key(character) {
66            return Err(Report::new(ChatError::UndefinedCharacter(
67                character.to_owned(),
68            )));
69        }
70        self.current_character = character.to_owned();
71        self.base.character_prompt = self.character_prompts[&self.current_character].clone();
72        Ok(())
73    }
74
75    pub fn add_user_message(&mut self, content: &str) -> Result<(), ChatError> {
76        self.base.add_message(Role::User, content)
77    }
78
79    pub fn add_system_message(&mut self, content: &str) -> Result<(), ChatError> {
80        self.base.add_message(Role::System, content)
81    }
82
83    pub fn add_message_with_parent_path(
84        &mut self,
85        path: &[usize],
86        role: Role,
87        content: &str,
88    ) -> Result<(), ChatError> {
89        self.base.add_message_with_parent_path(path, role, content)
90    }
91
92    pub async fn get_req_body_with_new_question(
93        &mut self,
94        parent_path: &[usize],
95        user_input: &str,
96    ) -> Result<serde_json::Value, ChatError> {
97        if self.current_character.is_empty() {
98            return Err(Report::new(ChatError::NoCharacterSelected));
99        }
100
101        self.base
102            .add_message_with_parent_path(parent_path, Role::User, user_input)?;
103
104        let character_role = Role::Character(self.current_character.clone());
105
106        Ok(self
107            .base
108            .build_request_body(&self.base.session.default_path.clone(), &character_role)?)
109    }
110
111    pub async fn get_req_body_again(
112        &mut self,
113        end_path: &[usize],
114    ) -> Result<serde_json::Value, ChatError> {
115        if self.current_character.is_empty() {
116            return Err(Report::new(ChatError::NoCharacterSelected));
117        }
118
119        let character_role = Role::Character(self.current_character.clone());
120
121        Ok(self.base.build_request_body(end_path, &character_role)?)
122    }
123
124    pub async fn get_req_body(&mut self, user_input: &str) -> Result<serde_json::Value, ChatError> {
125        info!("path: {:?}", self.base.session.default_path.clone());
126        self.get_req_body_with_new_question(&self.base.session.default_path.clone(), user_input)
127            .await
128    }
129
130    async fn get_content_from_req_body(
131        &mut self,
132        request_body: serde_json::Value,
133    ) -> Result<String, ChatError> {
134        let content = if self.need_stream {
135            let (stream, semaphore_permit) = self
136                .base
137                .get_stream_response(request_body.clone())
138                .await
139                .attach_printable("Failed to get stream response")?;
140
141            BaseChat::get_content_from_stream_resp(stream, semaphore_permit)
142                .await
143                .attach_printable("Failed to extract content from stream response")?
144        } else {
145            let response = self
146                .base
147                .get_response(request_body.clone())
148                .await
149                .attach_printable("Failed to get response")?;
150
151            BaseChat::get_content_from_resp(&response)
152                .attach_printable("Failed to extract content from response")?
153        };
154
155        info!(
156            "GetLLMAPIAnswer from {}: {}",
157            self.current_character, content
158        );
159
160        let character_role = Role::Character(self.current_character.clone());
161        self.base.add_message(character_role, &content)?;
162
163        Ok(content)
164    }
165
166    pub async fn get_answer(&mut self, user_input: &str) -> Result<String, ChatError> {
167        if self.current_character.is_empty() {
168            return Err(Report::new(ChatError::NoCharacterSelected));
169        }
170
171        let request_body = self.get_req_body(user_input).await?;
172
173        self.get_content_from_req_body(request_body).await
174    }
175
176    pub async fn get_json_answer<T: DeserializeOwned + 'static + JsonSchema>(
177        &mut self,
178        user_input: &str,
179    ) -> Result<T, ChatError> {
180        let schema = T::json_schema();
181
182        let output_description = assemble_output_description(schema.clone())
183            .change_context(ChatError::AssembleOutputDescriptionError)
184            .attach_printable(format!(
185                "Failed to assemble output description for schema: {:?}",
186                serde_json::to_string(&schema)
187                    .unwrap_or_else(|_| "Schema serialization failed".to_string())
188            ))?;
189
190        self.base
191            .add_message(Role::System, output_description.as_str())?;
192
193        let answer = self.get_answer(user_input).await?;
194
195        ChatTool::get_json::<T>(&answer, schema)
196            .await
197            .attach_printable(format!("Failed to parse answer as JSON: {}", answer))
198    }
199
200    pub async fn dialogue(
201        &mut self,
202        character: &str,
203        user_input: &str,
204    ) -> Result<String, ChatError> {
205        self.set_character(character)?;
206        self.add_user_message(user_input)?;
207        self.get_answer(user_input).await
208    }
209
210    pub async fn structured_dialogue<T: DeserializeOwned + 'static + JsonSchema>(
211        &mut self,
212        character: &str,
213        user_input: &str,
214    ) -> Result<T, ChatError> {
215        self.set_character(character)?;
216        self.add_user_message(user_input)?;
217        self.get_json_answer::<T>(user_input).await
218    }
219}