1use anyhow::Result;
2use futures_util::StreamExt;
3use inquire::Text;
4use inquire::ui::{Attributes, Color, RenderConfig, StyleSheet, Styled};
5use std::collections::HashMap;
6use std::io::{self, Write};
7
8use super::command::{Input, SlashCommand, SlashCommandCompleter, parse_input};
9use super::ui;
10use crate::config::CustomStyle;
11use crate::style;
12use crate::translation::{TranslationClient, TranslationRequest};
13use crate::ui::{Spinner, Style};
14
15#[derive(Debug, Clone)]
17pub struct SessionConfig {
18 pub provider_name: String,
20 pub endpoint: String,
22 pub model: String,
24 pub api_key: Option<String>,
26 pub to: String,
28 pub style_name: Option<String>,
30 pub style_prompt: Option<String>,
32 pub custom_styles: HashMap<String, CustomStyle>,
34}
35
36impl SessionConfig {
37 #[allow(clippy::missing_const_for_fn)] #[allow(clippy::too_many_arguments)]
40 pub fn new(
41 provider_name: String,
42 endpoint: String,
43 model: String,
44 api_key: Option<String>,
45 to: String,
46 style_name: Option<String>,
47 style_prompt: Option<String>,
48 custom_styles: HashMap<String, CustomStyle>,
49 ) -> Self {
50 Self {
51 provider_name,
52 endpoint,
53 model,
54 api_key,
55 to,
56 style_name,
57 style_prompt,
58 custom_styles,
59 }
60 }
61}
62
63pub struct ChatSession {
67 config: SessionConfig,
68 client: TranslationClient,
69}
70
71impl ChatSession {
72 pub fn new(config: SessionConfig) -> Self {
74 let client = TranslationClient::new(config.endpoint.clone(), config.api_key.clone());
75 Self { config, client }
76 }
77
78 pub async fn run(&mut self) -> Result<()> {
79 ui::print_header();
80
81 let prompt_style = Styled::new("❯")
82 .with_fg(Color::LightBlue)
83 .with_attr(Attributes::BOLD);
84 let mut render_config = RenderConfig::default()
85 .with_prompt_prefix(prompt_style)
86 .with_answered_prompt_prefix(prompt_style);
87
88 render_config.option = StyleSheet::new().with_fg(Color::Grey);
90 render_config.selected_option = Some(StyleSheet::new().with_fg(Color::DarkMagenta));
92
93 loop {
94 let input = Text::new("")
95 .with_render_config(render_config)
96 .with_autocomplete(SlashCommandCompleter)
97 .with_help_message("Type text to translate, /help for commands, Ctrl+C to quit")
98 .prompt();
99
100 match input {
101 Ok(line) => match parse_input(&line) {
102 Input::Empty => {}
103 Input::Command(cmd) => {
104 if !self.handle_command(cmd) {
105 break;
106 }
107 }
108 Input::Text(text) => {
109 self.translate_and_print(&text).await?;
110 }
111 },
112 Err(
113 inquire::InquireError::OperationCanceled
114 | inquire::InquireError::OperationInterrupted,
115 ) => {
116 println!(); break;
118 }
119 Err(e) => return Err(e.into()),
120 }
121 }
122
123 ui::print_goodbye();
124 Ok(())
125 }
126
127 fn handle_command(&mut self, cmd: SlashCommand) -> bool {
128 match cmd {
129 SlashCommand::Config => {
130 ui::print_config(&self.config);
131 true
132 }
133 SlashCommand::Help => {
134 ui::print_help();
135 true
136 }
137 SlashCommand::Quit => false,
138 SlashCommand::Set { key, value } => {
139 self.handle_set(&key, value.as_deref());
140 true
141 }
142 SlashCommand::Unknown(cmd) => {
143 ui::print_error(&format!("Unknown command: /{cmd}"));
144 true
145 }
146 }
147 }
148
149 fn handle_set(&mut self, key: &str, value: Option<&str>) {
150 match key {
151 "style" => self.set_style(value),
152 "to" => self.set_to(value),
153 "model" => self.set_model(value),
154 "" => {
155 println!("Usage: /set <key> <value>");
156 println!("Keys: style, to, model");
157 }
158 _ => {
159 ui::print_error(&format!("Unknown setting: {key}"));
160 println!("Available: style, to, model");
161 }
162 }
163 }
164
165 fn set_style(&mut self, value: Option<&str>) {
166 let Some(key) = value else {
167 self.config.style_name = None;
169 self.config.style_prompt = None;
170 println!("{} Style cleared", Style::success("✓"));
171 return;
172 };
173
174 let resolved = match style::resolve_style(key, &self.config.custom_styles) {
176 Ok(r) => r,
177 Err(e) => {
178 ui::print_error(&e.to_string());
179 return;
180 }
181 };
182
183 self.config.style_name = Some(key.to_string());
184 self.config.style_prompt = Some(resolved.prompt().to_string());
185 println!(
186 "{} Style set to {}\n",
187 Style::success("✓"),
188 Style::value(key)
189 );
190 }
191
192 fn set_to(&mut self, value: Option<&str>) {
193 match value {
194 None => {
195 ui::print_error("Usage: /set to <language>");
196 }
197 Some(lang) => {
198 self.config.to = lang.to_string();
199 println!(
200 "{} Target language set to {}",
201 Style::success("✓"),
202 Style::value(lang)
203 );
204 }
205 }
206 }
207
208 fn set_model(&mut self, value: Option<&str>) {
209 match value {
210 None => {
211 ui::print_error("Usage: /set model <name>");
212 }
213 Some(model) => {
214 self.config.model = model.to_string();
215 println!(
216 "{} Model set to {}",
217 Style::success("✓"),
218 Style::value(model)
219 );
220 }
221 }
222 }
223
224 async fn translate_and_print(&self, text: &str) -> Result<()> {
225 let request = TranslationRequest {
226 source_text: text.to_string(),
227 target_language: self.config.to.clone(),
228 model: self.config.model.clone(),
229 endpoint: self.config.endpoint.clone(),
230 style: self.config.style_prompt.clone(),
231 };
232
233 let spinner = Spinner::new("Translating...");
234
235 let mut stream = self.client.translate_stream(&request).await?;
236 let mut first_chunk = true;
237
238 while let Some(chunk_result) = stream.next().await {
239 let chunk = chunk_result?;
240
241 if first_chunk {
242 spinner.stop();
243 first_chunk = false;
244 }
245
246 print!("{chunk}");
247 io::stdout().flush()?;
248 }
249
250 if first_chunk {
251 spinner.stop();
252 }
253
254 println!();
255 println!();
256 Ok(())
257 }
258}
259
#[cfg(test)]
mod tests {
    use super::*;

    /// `SessionConfig::new` must store every argument unchanged in the matching field.
    #[test]
    fn test_session_config_new() {
        let styles = HashMap::from([(
            "my_style".to_string(),
            CustomStyle {
                description: "My description".to_string(),
                prompt: "My custom prompt".to_string(),
            },
        )]);

        let cfg = SessionConfig::new(
            "ollama".to_string(),
            "http://localhost:11434".to_string(),
            "gemma3:12b".to_string(),
            None,
            "ja".to_string(),
            Some("casual".to_string()),
            Some("Use a casual tone.".to_string()),
            styles,
        );

        assert_eq!(cfg.provider_name, "ollama");
        assert_eq!(cfg.endpoint, "http://localhost:11434");
        assert_eq!(cfg.model, "gemma3:12b");
        assert!(cfg.api_key.is_none());
        assert_eq!(cfg.to, "ja");
        assert_eq!(cfg.style_name.as_deref(), Some("casual"));
        assert_eq!(cfg.style_prompt.as_deref(), Some("Use a casual tone."));

        let custom_prompt = cfg
            .custom_styles
            .get("my_style")
            .map(|style| style.prompt.as_str());
        assert_eq!(custom_prompt, Some("My custom prompt"));
    }
}