// aster/agents/moim.rs

use crate::agents::extension_manager::ExtensionManager;
use crate::conversation::message::Message;
use crate::conversation::{fix_conversation, Conversation};
use rmcp::model::Role;

// Test-only hook: never set this from production code. It is deliberately not gated
// behind `#[cfg(test)]` because it is used from outside this crate, and a crate's
// `cfg(test)` items are not compiled when some other crate's tests depend on it.
thread_local! {
    /// When `true` on the current thread, `inject_moim` returns the conversation it was
    /// given unchanged, without querying the extension manager. Starts out `false`.
    pub static SKIP: std::cell::Cell<bool> = const { std::cell::Cell::new(false) };
}

11pub async fn inject_moim(
12    conversation: Conversation,
13    extension_manager: &ExtensionManager,
14) -> Conversation {
15    if SKIP.with(|f| f.get()) {
16        return conversation;
17    }
18
19    if let Some(moim) = extension_manager.collect_moim().await {
20        let mut messages = conversation.messages().clone();
21        let idx = messages
22            .iter()
23            .rposition(|m| m.role == Role::Assistant)
24            .unwrap_or(0);
25        messages.insert(idx, Message::user().with_text(moim));
26
27        let (fixed, issues) = fix_conversation(Conversation::new_unvalidated(messages));
28
29        let has_unexpected_issues = issues.iter().any(|issue| {
30            !issue.contains("Merged consecutive user messages")
31                && !issue.contains("Merged consecutive assistant messages")
32        });
33
34        if has_unexpected_issues {
35            tracing::warn!("MOIM injection caused unexpected issues: {:?}", issues);
36            return conversation;
37        }
38
39        return fixed;
40    }
41    conversation
42}
43
#[cfg(test)]
mod tests {
    use super::*;
    use rmcp::model::CallToolRequestParam;

    /// Concatenates every text item of `message` into a single string.
    fn joined_text(message: &Message) -> String {
        message
            .content
            .iter()
            .filter_map(|c| c.as_text())
            .collect::<Vec<_>>()
            .join("")
    }

    /// A `search` tool call carrying no arguments.
    fn search_call() -> CallToolRequestParam {
        CallToolRequestParam {
            name: "search".into(),
            arguments: None,
        }
    }

    /// A successful tool result with no content.
    fn empty_tool_result() -> rmcp::model::CallToolResult {
        rmcp::model::CallToolResult {
            content: vec![],
            structured_content: None,
            is_error: Some(false),
            meta: None,
        }
    }

    #[tokio::test]
    async fn test_moim_injection_before_assistant() {
        let manager = ExtensionManager::new_without_provider();

        let conversation = Conversation::new_unvalidated(vec![
            Message::user().with_text("Hello"),
            Message::assistant().with_text("Hi"),
            Message::user().with_text("Bye"),
        ]);
        let injected = inject_moim(conversation, &manager).await;
        let messages = injected.messages();

        assert_eq!(messages.len(), 3);
        assert_eq!(messages[0].content[0].as_text().unwrap(), "Hello");
        assert_eq!(messages[1].content[0].as_text().unwrap(), "Hi");

        // The MOIM was inserted before "Hi" and merged into the preceding user message.
        let first_text = joined_text(&messages[0]);
        assert!(first_text.contains("Hello"));
        assert!(first_text.contains("<info-msg>"));
    }

    #[tokio::test]
    async fn test_moim_injection_no_assistant() {
        let manager = ExtensionManager::new_without_provider();

        let conversation =
            Conversation::new_unvalidated(vec![Message::user().with_text("Hello")]);
        let injected = inject_moim(conversation, &manager).await;

        assert_eq!(injected.messages().len(), 1);

        // With no assistant turn the MOIM goes first and merges with the lone user message.
        let only_text = joined_text(&injected.messages()[0]);
        assert!(only_text.contains("Hello"));
        assert!(only_text.contains("<info-msg>"));
    }

    #[tokio::test]
    async fn test_moim_with_tool_calls() {
        let manager = ExtensionManager::new_without_provider();

        let conversation = Conversation::new_unvalidated(vec![
            Message::user().with_text("Search for something"),
            Message::assistant()
                .with_text("I'll search for you")
                .with_tool_request("search_1", Ok(search_call())),
            Message::user().with_tool_response("search_1", Ok(empty_tool_result())),
            Message::assistant()
                .with_text("I need to search more")
                .with_tool_request("search_2", Ok(search_call())),
            Message::user().with_tool_response("search_2", Ok(empty_tool_result())),
        ]);

        let injected = inject_moim(conversation, &manager).await;
        let messages = injected.messages();

        assert_eq!(messages.len(), 6);

        let has_moim = messages[3]
            .content
            .iter()
            .any(|c| c.as_text().is_some_and(|t| t.contains("<info-msg>")));

        assert!(
            has_moim,
            "MOIM should be in message before latest assistant message"
        );
    }
}