tl_cli/translation/
client.rs1use anyhow::{Context, Result};
2use bytes::Bytes;
3use futures_util::Stream;
4use reqwest::Client;
5use serde::Serialize;
6use sha2::{Digest, Sha256};
7use std::borrow::Cow;
8use std::pin::Pin;
9
10use super::prompt::{SYSTEM_PROMPT_TEMPLATE, build_system_prompt_with_style};
11use super::sse_parser::sse_to_text_stream;
12
/// Everything needed to perform (and cache) a single translation.
#[derive(Debug, Clone)]
pub struct TranslationRequest {
    /// The text to be translated; sent verbatim as the user message.
    pub source_text: String,
    /// Target language identifier (e.g. `"ja"`), used to build the system prompt.
    pub target_language: String,
    /// Model name forwarded to the chat-completions API (e.g. `"gemma3:12b"`).
    pub model: String,
    /// Base URL of the API server; part of the cache key.
    pub endpoint: String,
    /// Optional translation style passed to the system-prompt builder.
    pub style: Option<String>,
}
30
31impl TranslationRequest {
32 pub fn cache_key(&self) -> String {
37 let prompt_hash = Self::prompt_hash();
38
39 let cache_input = serde_json::json!({
40 "source_text": self.source_text,
41 "target_language": self.target_language,
42 "model": self.model,
43 "endpoint": self.endpoint,
44 "prompt_hash": prompt_hash,
45 "style": self.style
46 });
47
48 let mut hasher = Sha256::new();
49 hasher.update(cache_input.to_string().as_bytes());
50 hex::encode(hasher.finalize())
51 }
52
53 pub fn prompt_hash() -> String {
57 let mut hasher = Sha256::new();
58 hasher.update(SYSTEM_PROMPT_TEMPLATE.as_bytes());
59 hex::encode(hasher.finalize())
60 }
61}
62
63#[derive(Debug, Serialize)]
65struct ChatCompletionRequest<'a> {
66 model: &'a str,
67 messages: Vec<Message<'a>>,
68 stream: bool,
69}
70
71impl<'a> ChatCompletionRequest<'a> {
72 fn for_translation(model: &'a str, system_prompt: &'a str, source_text: &'a str) -> Self {
74 Self {
75 model,
76 messages: vec![
77 Message {
78 role: "system",
79 content: Cow::Borrowed(system_prompt),
80 },
81 Message {
82 role: "user",
83 content: Cow::Borrowed(source_text),
84 },
85 ],
86 stream: true,
87 }
88 }
89}
90
91#[derive(Debug, Serialize)]
92struct Message<'a> {
93 role: &'static str,
94 content: Cow<'a, str>,
95}
96
97pub struct TranslationClient {
129 client: Client,
130 endpoint: String,
131 api_key: Option<String>,
132}
133
134impl TranslationClient {
135 pub fn new(endpoint: String, api_key: Option<String>) -> Self {
137 Self {
138 client: Client::new(),
139 endpoint,
140 api_key,
141 }
142 }
143
144 pub async fn translate_stream(
149 &self,
150 request: &TranslationRequest,
151 ) -> Result<Pin<Box<dyn Stream<Item = Result<String>> + Send>>> {
152 let byte_stream = self
153 .send_chat_completion(
154 &request.model,
155 &request.target_language,
156 &request.source_text,
157 request.style.as_deref(),
158 )
159 .await?;
160
161 Ok(Box::pin(sse_to_text_stream(byte_stream)))
162 }
163
164 async fn send_chat_completion(
166 &self,
167 model: &str,
168 target_language: &str,
169 source_text: &str,
170 style: Option<&str>,
171 ) -> Result<impl Stream<Item = reqwest::Result<Bytes>> + Send + 'static> {
172 let url = self.build_url();
173 let system_prompt = build_system_prompt_with_style(target_language, style);
174 let chat_request =
175 ChatCompletionRequest::for_translation(model, &system_prompt, source_text);
176
177 let response = self.send_request(&url, &chat_request).await?;
178
179 Ok(response.bytes_stream())
180 }
181
182 async fn send_request<T: Serialize + Sync>(
184 &self,
185 url: &str,
186 body: &T,
187 ) -> Result<reqwest::Response> {
188 let mut request = self.client.post(url).json(body);
189
190 if let Some(api_key) = &self.api_key {
191 request = request.header("Authorization", format!("Bearer {api_key}"));
192 }
193
194 let response = request
195 .send()
196 .await
197 .with_context(|| format!("Failed to connect to API endpoint: {url}"))?;
198
199 if !response.status().is_success() {
200 let status = response.status();
201 let body = response.text().await.unwrap_or_default();
202 anyhow::bail!("API request failed with status {status}: {body}");
203 }
204
205 Ok(response)
206 }
207
208 fn build_url(&self) -> String {
210 format!(
211 "{}/v1/chat/completions",
212 self.endpoint.trim_end_matches('/')
213 )
214 }
215}
216
#[cfg(test)]
mod tests {
    use super::*;

    /// Baseline request shared by the tests; individual tests mutate one field at a time.
    fn create_test_request() -> TranslationRequest {
        TranslationRequest {
            source_text: "Hello, world!".to_string(),
            target_language: "ja".to_string(),
            model: "gemma3:12b".to_string(),
            endpoint: "http://localhost:11434".to_string(),
            style: None,
        }
    }

    #[test]
    fn test_cache_key_is_consistent() {
        let request = create_test_request();
        let key1 = request.cache_key();
        let key2 = request.cache_key();
        assert_eq!(key1, key2);
    }

    #[test]
    fn test_cache_key_is_hex_string() {
        let request = create_test_request();
        let key = request.cache_key();
        // SHA-256 digest = 32 bytes = 64 hex characters.
        assert_eq!(key.len(), 64);
        assert!(key.chars().all(|c| c.is_ascii_hexdigit()));
    }

    #[test]
    fn test_cache_key_differs_for_different_source_text() {
        let request1 = create_test_request();
        let mut request2 = create_test_request();
        request2.source_text = "Different text".to_string();
        assert_ne!(request1.cache_key(), request2.cache_key());
    }

    #[test]
    fn test_cache_key_differs_for_different_target_language() {
        let request1 = create_test_request();
        let mut request2 = create_test_request();
        request2.target_language = "en".to_string();
        assert_ne!(request1.cache_key(), request2.cache_key());
    }

    #[test]
    fn test_cache_key_differs_for_different_model() {
        let request1 = create_test_request();
        let mut request2 = create_test_request();
        request2.model = "gpt-4o".to_string();
        assert_ne!(request1.cache_key(), request2.cache_key());
    }

    #[test]
    fn test_cache_key_differs_for_different_endpoint() {
        let request1 = create_test_request();
        let mut request2 = create_test_request();
        request2.endpoint = "https://api.openai.com".to_string();
        assert_ne!(request1.cache_key(), request2.cache_key());
    }

    #[test]
    fn test_cache_key_differs_for_different_style() {
        let request1 = create_test_request();
        let mut request2 = create_test_request();
        request2.style = Some("formal".to_string());
        assert_ne!(request1.cache_key(), request2.cache_key());
    }

    #[test]
    fn test_prompt_hash_is_consistent() {
        let hash1 = TranslationRequest::prompt_hash();
        let hash2 = TranslationRequest::prompt_hash();
        assert_eq!(hash1, hash2);
    }

    #[test]
    fn test_prompt_hash_is_hex_string() {
        let hash = TranslationRequest::prompt_hash();
        // SHA-256 digest = 32 bytes = 64 hex characters.
        assert_eq!(hash.len(), 64);
        assert!(hash.chars().all(|c| c.is_ascii_hexdigit()));
    }

    #[test]
    fn test_chat_completion_request_serialization() {
        let request =
            ChatCompletionRequest::for_translation("gemma3:12b", "system prompt", "Hello");
        let value = serde_json::to_value(&request).expect("request should serialize");
        assert_eq!(
            value,
            serde_json::json!({
                "model": "gemma3:12b",
                "messages": [
                    {"role": "system", "content": "system prompt"},
                    {"role": "user", "content": "Hello"}
                ],
                "stream": true
            })
        );
    }

    #[test]
    fn test_translation_client_new() {
        let client = TranslationClient::new(
            "http://localhost:11434".to_string(),
            Some("test-api-key".to_string()),
        );
        assert_eq!(client.endpoint, "http://localhost:11434");
        assert_eq!(client.api_key, Some("test-api-key".to_string()));
    }

    #[test]
    fn test_translation_client_new_without_api_key() {
        let client = TranslationClient::new("http://localhost:11434".to_string(), None);
        assert_eq!(client.endpoint, "http://localhost:11434");
        assert!(client.api_key.is_none());
    }

    #[test]
    fn test_build_url_without_trailing_slash() {
        let client = TranslationClient::new("http://localhost:11434".to_string(), None);
        assert_eq!(
            client.build_url(),
            "http://localhost:11434/v1/chat/completions"
        );
    }

    #[test]
    fn test_build_url_with_trailing_slash() {
        let client = TranslationClient::new("http://localhost:11434/".to_string(), None);
        assert_eq!(
            client.build_url(),
            "http://localhost:11434/v1/chat/completions"
        );
    }
}