Skip to main content

ai_agents_hitl/
localization.rs

1use serde_json::Value;
2use std::collections::HashMap;
3
4use ai_agents_core::{AgentError, Result};
5use ai_agents_llm::{ChatMessage, LLMRegistry};
6
7use super::config::{
8    ApprovalMessage, LlmGenerateConfig, MessageLanguageConfig, MessageLanguageStrategy,
9    ToolApprovalConfig,
10};
11use super::handler::ApprovalHandler;
12
13pub struct MessageResolver<'a> {
14    global_config: &'a MessageLanguageConfig,
15    llm_registry: Option<&'a LLMRegistry>,
16}
17
18impl<'a> MessageResolver<'a> {
19    pub fn new(global_config: &'a MessageLanguageConfig) -> Self {
20        Self {
21            global_config,
22            llm_registry: None,
23        }
24    }
25
26    pub fn with_llm_registry(mut self, registry: &'a LLMRegistry) -> Self {
27        self.llm_registry = Some(registry);
28        self
29    }
30
31    pub async fn resolve(
32        &self,
33        approval_message: &ApprovalMessage,
34        local_config: Option<&MessageLanguageConfig>,
35        context: &HashMap<String, Value>,
36        handler: &dyn ApprovalHandler,
37    ) -> Result<String> {
38        let effective_config = local_config.unwrap_or(self.global_config);
39        let strategies = std::iter::once(effective_config.strategy.clone())
40            .chain(effective_config.fallback.iter().cloned());
41
42        for strategy in strategies {
43            if let Some(message) = self
44                .try_strategy(
45                    strategy,
46                    approval_message,
47                    effective_config,
48                    context,
49                    handler,
50                )
51                .await?
52            {
53                return Ok(render_template(&message, context));
54            }
55        }
56
57        // Final fallback: use the raw message and render templates
58        let raw = approval_message
59            .get_any()
60            .or_else(|| approval_message.description().map(String::from))
61            .unwrap_or_else(|| "Approval required".to_string());
62        Ok(render_template(&raw, context))
63    }
64
65    async fn try_strategy(
66        &self,
67        strategy: MessageLanguageStrategy,
68        approval_message: &ApprovalMessage,
69        config: &MessageLanguageConfig,
70        context: &HashMap<String, Value>,
71        handler: &dyn ApprovalHandler,
72    ) -> Result<Option<String>> {
73        match strategy {
74            MessageLanguageStrategy::Auto => Ok(None),
75
76            MessageLanguageStrategy::Approver => {
77                if let Some(lang) = handler.preferred_language() {
78                    return Ok(approval_message.get(&lang));
79                }
80                Ok(None)
81            }
82
83            MessageLanguageStrategy::User => {
84                let lang = get_user_language(context);
85                if let Some(lang) = lang {
86                    return Ok(approval_message.get(&lang));
87                }
88                Ok(None)
89            }
90
91            MessageLanguageStrategy::Explicit => {
92                if let Some(ref lang) = config.explicit {
93                    return Ok(approval_message.get(lang));
94                }
95                Ok(None)
96            }
97
98            MessageLanguageStrategy::LlmGenerate => {
99                if let Some(registry) = self.llm_registry {
100                    match self
101                        .generate_message_with_llm(
102                            approval_message,
103                            config.llm_generate.as_ref(),
104                            context,
105                            handler,
106                            registry,
107                        )
108                        .await
109                    {
110                        Ok(message) => return Ok(Some(message)),
111                        Err(_) => return Ok(None),
112                    }
113                }
114                Ok(None)
115            }
116        }
117    }
118
119    async fn generate_message_with_llm(
120        &self,
121        approval_message: &ApprovalMessage,
122        llm_config: Option<&LlmGenerateConfig>,
123        context: &HashMap<String, Value>,
124        handler: &dyn ApprovalHandler,
125        registry: &LLMRegistry,
126    ) -> Result<String> {
127        let config = llm_config.cloned().unwrap_or_default();
128
129        let target_lang = handler
130            .preferred_language()
131            .or_else(|| get_user_language(context))
132            .unwrap_or_else(|| "English".to_string());
133
134        let description = approval_message
135            .description()
136            .map(String::from)
137            .or_else(|| approval_message.get_any())
138            .unwrap_or_else(|| "Approval required for this action".to_string());
139
140        let context_str = if config.include_context && !context.is_empty() {
141            format!(
142                "\nContext: {}",
143                serde_json::to_string_pretty(context).unwrap_or_default()
144            )
145        } else {
146            String::new()
147        };
148
149        let prompt = format!(
150            "Generate a clear, concise approval request message in {}.\n\
151             Action: {}{}\n\
152             Requirements:\n\
153             - Keep it under 100 words\n\
154             - Be direct and professional\n\
155             - Include the actual context values in the message, not placeholders\n\
156             - Output only the message text, no explanations",
157            target_lang, description, context_str
158        );
159
160        let llm = registry.get(&config.llm).map_err(|e| {
161            AgentError::Other(format!("Failed to get LLM for message generation: {}", e))
162        })?;
163
164        let response = llm
165            .complete(&[ChatMessage::user(&prompt)], None)
166            .await
167            .map_err(|e| AgentError::Other(format!("LLM generation failed: {}", e)))?;
168
169        Ok(response.content.trim().to_string())
170    }
171}
172
173fn get_user_language(context: &HashMap<String, Value>) -> Option<String> {
174    context
175        .get("user.language")
176        .and_then(|v| v.as_str())
177        .or_else(|| {
178            context
179                .get("input.detected.language")
180                .and_then(|v| v.as_str())
181        })
182        .or_else(|| context.get("language").and_then(|v| v.as_str()))
183        .map(String::from)
184}
185
186pub(crate) fn render_template(template: &str, context: &HashMap<String, Value>) -> String {
187    let mut result = template.to_string();
188
189    for (key, value) in context {
190        let placeholder = format!("{{{{ {} }}}}", key);
191        let replacement = match value {
192            Value::String(s) => s.clone(),
193            Value::Number(n) => n.to_string(),
194            Value::Bool(b) => b.to_string(),
195            Value::Null => "null".to_string(),
196            _ => serde_json::to_string(value).unwrap_or_default(),
197        };
198        result = result.replace(&placeholder, &replacement);
199    }
200
201    result
202}
203
204pub fn resolve_best_language(
205    approval_message: &ApprovalMessage,
206    handler: &dyn ApprovalHandler,
207    context: &HashMap<String, Value>,
208) -> Option<String> {
209    if let Some(lang) = handler.preferred_language() {
210        if approval_message.get(&lang).is_some() {
211            return Some(lang);
212        }
213    }
214
215    if let Some(lang) = get_user_language(context) {
216        if approval_message.get(&lang).is_some() {
217            return Some(lang);
218        }
219    }
220
221    if approval_message.get("en").is_some() {
222        return Some("en".to_string());
223    }
224
225    approval_message.available_languages().into_iter().next()
226}
227
228pub async fn resolve_tool_message(
229    tool_config: &ToolApprovalConfig,
230    tool_name: &str,
231    global_config: &MessageLanguageConfig,
232    context: &HashMap<String, Value>,
233    handler: &dyn ApprovalHandler,
234    llm_registry: Option<&LLMRegistry>,
235) -> Result<String> {
236    if tool_config.approval_message.is_empty() {
237        return Ok(format!("Approve execution of tool '{}'?", tool_name));
238    }
239
240    let mut resolver = MessageResolver::new(global_config);
241    if let Some(registry) = llm_registry {
242        resolver = resolver.with_llm_registry(registry);
243    }
244
245    resolver
246        .resolve(
247            &tool_config.approval_message,
248            tool_config.message_language.as_ref(),
249            context,
250            handler,
251        )
252        .await
253}
254
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::ApprovalResult;

    /// Approval handler stub: always approves and reports an optional preferred language.
    struct TestHandler {
        language: Option<String>,
    }

    impl TestHandler {
        fn new() -> Self {
            Self { language: None }
        }

        fn with_language(language: impl Into<String>) -> Self {
            Self {
                language: Some(language.into()),
            }
        }
    }

    #[async_trait::async_trait]
    impl ApprovalHandler for TestHandler {
        async fn request_approval(
            &self,
            _request: crate::types::ApprovalRequest,
        ) -> ApprovalResult {
            ApprovalResult::Approved
        }

        fn preferred_language(&self) -> Option<String> {
            self.language.clone()
        }
    }

    /// Builds a language config with no LLM-generation settings.
    fn lang_config(
        strategy: MessageLanguageStrategy,
        fallback: Vec<MessageLanguageStrategy>,
        explicit: Option<&str>,
    ) -> MessageLanguageConfig {
        MessageLanguageConfig {
            strategy,
            fallback,
            explicit: explicit.map(String::from),
            llm_generate: None,
        }
    }

    /// Builds a multi-language approval message from `(language, body)` pairs.
    fn multi_message(entries: &[(&str, &str)]) -> ApprovalMessage {
        let messages: HashMap<String, String> = entries
            .iter()
            .map(|&(lang, body)| (lang.to_string(), body.to_string()))
            .collect();
        ApprovalMessage::multi_language(messages)
    }

    /// Builds a context map from `(key, value)` pairs.
    fn context_of(entries: Vec<(&str, Value)>) -> HashMap<String, Value> {
        entries
            .into_iter()
            .map(|(key, value)| (key.to_string(), value))
            .collect()
    }

    /// Shorthand for a JSON string value.
    fn text(s: &str) -> Value {
        Value::String(s.to_string())
    }

    /// Runs `MessageResolver::resolve` with the given inputs and unwraps the result.
    async fn resolve_message(
        global: &MessageLanguageConfig,
        local: Option<&MessageLanguageConfig>,
        message: &ApprovalMessage,
        context: &HashMap<String, Value>,
        handler: &TestHandler,
    ) -> String {
        MessageResolver::new(global)
            .resolve(message, local, context, handler)
            .await
            .unwrap()
    }

    #[test]
    fn test_render_template_basic() {
        let context = context_of(vec![
            ("amount", Value::Number(1000.into())),
            ("currency", text("USD")),
        ]);

        assert_eq!(
            render_template("Approve {{ amount }} {{ currency }}?", &context),
            "Approve 1000 USD?"
        );
    }

    #[test]
    fn test_render_template_no_placeholders() {
        let context = HashMap::new();
        assert_eq!(render_template("Simple message", &context), "Simple message");
    }

    #[test]
    fn test_get_user_language_from_user_language() {
        let context = context_of(vec![("user.language", text("ko"))]);
        assert_eq!(get_user_language(&context), Some("ko".to_string()));
    }

    #[test]
    fn test_get_user_language_from_detected() {
        let context = context_of(vec![("input.detected.language", text("ja"))]);
        assert_eq!(get_user_language(&context), Some("ja".to_string()));
    }

    #[test]
    fn test_get_user_language_priority() {
        let context = context_of(vec![
            ("user.language", text("ko")),
            ("input.detected.language", text("ja")),
        ]);
        assert_eq!(get_user_language(&context), Some("ko".to_string()));
    }

    #[test]
    fn test_get_user_language_none() {
        let context = HashMap::new();
        assert_eq!(get_user_language(&context), None);
    }

    #[tokio::test]
    async fn test_resolve_approver_strategy() {
        let config = lang_config(MessageLanguageStrategy::Approver, vec![], None);
        let message = multi_message(&[("en", "English message"), ("ko", "한국어 메시지")]);
        let handler = TestHandler::with_language("ko");
        let context = HashMap::new();

        let result = resolve_message(&config, None, &message, &context, &handler).await;
        assert_eq!(result, "한국어 메시지");
    }

    #[tokio::test]
    async fn test_resolve_user_strategy() {
        let config = lang_config(MessageLanguageStrategy::User, vec![], None);
        let message = multi_message(&[("en", "English message"), ("ja", "日本語メッセージ")]);
        let handler = TestHandler::new();
        let context = context_of(vec![("user.language", text("ja"))]);

        let result = resolve_message(&config, None, &message, &context, &handler).await;
        assert_eq!(result, "日本語メッセージ");
    }

    #[tokio::test]
    async fn test_resolve_explicit_strategy() {
        let config = lang_config(MessageLanguageStrategy::Explicit, vec![], Some("en"));
        let message = multi_message(&[("en", "English message"), ("ko", "한국어 메시지")]);
        let handler = TestHandler::with_language("ko");
        let context = HashMap::new();

        let result = resolve_message(&config, None, &message, &context, &handler).await;
        assert_eq!(result, "English message");
    }

    #[tokio::test]
    async fn test_resolve_fallback_chain() {
        let config = lang_config(
            MessageLanguageStrategy::Approver,
            vec![
                MessageLanguageStrategy::User,
                MessageLanguageStrategy::Explicit,
            ],
            Some("en"),
        );
        let message = multi_message(&[("en", "English fallback")]);
        let handler = TestHandler::with_language("ko");
        let context = HashMap::new();

        let result = resolve_message(&config, None, &message, &context, &handler).await;
        assert_eq!(result, "English fallback");
    }

    #[tokio::test]
    async fn test_resolve_simple_message() {
        let config = MessageLanguageConfig::default();
        let message = ApprovalMessage::simple("Simple approval message");
        let handler = TestHandler::with_language("ko");
        let context = HashMap::new();

        let result = resolve_message(&config, None, &message, &context, &handler).await;
        assert_eq!(result, "Simple approval message");
    }

    #[tokio::test]
    async fn test_resolve_with_template() {
        let config = lang_config(MessageLanguageStrategy::Explicit, vec![], Some("en"));
        let message = multi_message(&[("en", "Approve {{ amount }} {{ currency }}?")]);
        let handler = TestHandler::new();
        let context = context_of(vec![
            ("amount", Value::Number(500.into())),
            ("currency", text("USD")),
        ]);

        let result = resolve_message(&config, None, &message, &context, &handler).await;
        assert_eq!(result, "Approve 500 USD?");
    }

    #[tokio::test]
    async fn test_resolve_local_config_override() {
        let global_config = lang_config(MessageLanguageStrategy::Approver, vec![], None);
        let local_config = lang_config(MessageLanguageStrategy::Explicit, vec![], Some("ja"));
        let message = multi_message(&[("en", "English"), ("ja", "日本語")]);
        let handler = TestHandler::with_language("en");
        let context = HashMap::new();

        let result = resolve_message(
            &global_config,
            Some(&local_config),
            &message,
            &context,
            &handler,
        )
        .await;
        assert_eq!(result, "日本語");
    }

    #[test]
    fn test_resolve_best_language_approver_preferred() {
        let message = multi_message(&[("en", "English"), ("ko", "한국어")]);
        let handler = TestHandler::with_language("ko");
        let context = HashMap::new();

        let result = resolve_best_language(&message, &handler, &context);
        assert_eq!(result, Some("ko".to_string()));
    }

    #[test]
    fn test_resolve_best_language_user_fallback() {
        let message = multi_message(&[("en", "English"), ("ja", "日本語")]);
        let handler = TestHandler::with_language("ko");
        let context = context_of(vec![("user.language", text("ja"))]);

        let result = resolve_best_language(&message, &handler, &context);
        assert_eq!(result, Some("ja".to_string()));
    }

    #[test]
    fn test_resolve_best_language_english_default() {
        let message = multi_message(&[("en", "English"), ("fr", "Français")]);
        let handler = TestHandler::with_language("ko");
        let context = HashMap::new();

        let result = resolve_best_language(&message, &handler, &context);
        assert_eq!(result, Some("en".to_string()));
    }

    #[tokio::test]
    async fn test_resolve_tool_message_empty() {
        let tool_config = ToolApprovalConfig::default();
        let global_config = MessageLanguageConfig::default();
        let context = HashMap::new();
        let handler = TestHandler::new();

        let result = resolve_tool_message(
            &tool_config,
            "test_tool",
            &global_config,
            &context,
            &handler,
            None,
        )
        .await
        .unwrap();

        assert_eq!(result, "Approve execution of tool 'test_tool'?");
    }

    #[tokio::test]
    async fn test_resolve_tool_message_with_config() {
        let tool_config = ToolApprovalConfig {
            require_approval: true,
            approval_context: vec![],
            approval_message: multi_message(&[("en", "Approve test?")]),
            message_language: Some(lang_config(
                MessageLanguageStrategy::Explicit,
                vec![],
                Some("en"),
            )),
            timeout_seconds: None,
        };
        let global_config = MessageLanguageConfig::default();
        let context = HashMap::new();
        let handler = TestHandler::new();

        let result = resolve_tool_message(
            &tool_config,
            "test_tool",
            &global_config,
            &context,
            &handler,
            None,
        )
        .await
        .unwrap();

        assert_eq!(result, "Approve test?");
    }
}