Skip to main content

microagents_core/
common.rs

1use std::sync::Arc;
2#[cfg(feature = "token_estimation")]
3use std::sync::OnceLock;
4
5use microagents_events::{AgentEventAny, types::ToolResult};
6use serde_json::Value;
7use ultrafast_models_sdk::{
8    Message, Role,
9    models::{FunctionCall, ToolCall},
10};
11
12use crate::types::{AgentError, ToolExecutionContext, ToolFunction};
13
/// Return the process-wide GPT-2 tokenizer, loading it on first use.
///
/// The outcome of the first load — either the tokenizer or the
/// `tokie::HubError` it produced — is memoised in a [`OnceLock`]. Later calls
/// hand back that same cached value, so a failed load is never retried.
#[cfg(feature = "token_estimation")]
fn tokenizer() -> &'static Result<tokie::Tokenizer, tokie::HubError> {
    // Kept function-local: no other code in this module reads the cache directly.
    static TOKENIZER: OnceLock<Result<tokie::Tokenizer, tokie::HubError>> = OnceLock::new();
    TOKENIZER.get_or_init(|| tokie::Tokenizer::from_pretrained("gpt2"))
}
21
/// Check that the environment variable named by `api_key` (e.g. `"OPENAI_API_KEY"`)
/// is set.
///
/// The variable's value is read and immediately discarded; only its presence matters.
///
/// # Errors
///
/// Propagates the [`std::env::VarError`] from [`std::env::var`]:
/// `NotPresent` when the variable is unset, `NotUnicode` when its value is not
/// valid Unicode.
pub fn check_api_key(api_key: &str) -> Result<(), std::env::VarError> {
    std::env::var(api_key).map(drop)
}
29
30/// Convert a persisted [`AgentEventAny`] back into an SDK [`Message`].
31///
32/// Only events that correspond to chat roles (`User`, `Assistant`, `Tool`)
33/// produce a message. All other variants return [`None`].
34pub fn convert_event_to_message(event: AgentEventAny) -> Option<Message> {
35    match event {
36        AgentEventAny::UserPromptSubmit(p) => Some(Message {
37            role: Role::User,
38            content: p.prompt,
39            name: None,
40            tool_calls: None,
41            tool_call_id: None,
42        }),
43        AgentEventAny::AssistantResponse(p) => {
44            let msg = if let Some(tc) = p.tool_calls {
45                let calls: Vec<ToolCall> = tc
46                    .iter()
47                    .map(|t| ToolCall {
48                        call_type: t.call_type.clone(),
49                        id: t.id.clone(),
50                        function: FunctionCall {
51                            name: t.function.name.clone(),
52                            arguments: t.function.arguments.clone(),
53                        },
54                    })
55                    .collect();
56                Message {
57                    role: Role::Assistant,
58                    content: p.full_text,
59                    name: None,
60                    tool_calls: Some(calls),
61                    tool_call_id: None,
62                }
63            } else {
64                Message {
65                    role: Role::Assistant,
66                    content: p.full_text,
67                    name: None,
68                    tool_calls: None,
69                    tool_call_id: None,
70                }
71            };
72            Some(msg)
73        }
74        AgentEventAny::ToolResult(p) => {
75            let result = match p.result {
76                ToolResult::Ok(r) => format!("Tool call succeeded: {}", r),
77                ToolResult::Err(r) => format!("Tool call failed: {}", r),
78                _ => unreachable!("ToolResult should not reach this branch"),
79            };
80            Some(Message {
81                role: Role::Tool,
82                content: result,
83                name: None,
84                tool_calls: None,
85                tool_call_id: Some(p.tool_call_id),
86            })
87        }
88        _ => None,
89    }
90}
91
92/// Result of attempting to parse a (potentially partial) JSON string.
93pub enum JsonResult {
94    /// Fully valid JSON value.
95    Valid(Value),
96    /// The input is a valid prefix but truncated (EOF while parsing).
97    Incomplete,
98    /// The input is not valid JSON.
99    Malformed,
100}
101
102/// Parse a JSON string that may be incomplete (e.g. streaming tool arguments).
103///
104/// Returns [`JsonResult::Incomplete`] when the payload is cut off mid-token,
105/// allowing the caller to buffer and retry.
106pub fn parse_json_fragment(s: &str) -> JsonResult {
107    let v = serde_json::from_str::<Value>(s);
108    match v {
109        Ok(val) => JsonResult::Valid(val),
110        Err(e) => {
111            if e.is_eof() {
112                return JsonResult::Incomplete;
113            }
114            JsonResult::Malformed
115        }
116    }
117}
118
119/// Validate tool arguments against its JSON schema and then execute it.
120///
121/// This is the canonical entry-point for invoking a [`ToolFunction`] from the
122/// agent runtime. It first checks schema conformance with `jsonschema`, then
123/// calls [`ToolFunction::execute`].
124pub async fn call_tool<Ctx: Send + Sync + 'static>(
125    tool: Arc<dyn ToolFunction<Ctx>>,
126    tool_args: Value,
127    tool_context: Arc<ToolExecutionContext<Ctx>>,
128) -> Result<ToolResult, AgentError> {
129    jsonschema::validate(&tool.input_schema(), &tool_args)
130        .map_err(|e| AgentError::ToolCallError(e.to_string()))?;
131    let result = tool.execute(tool_args, &tool_context).await?;
132    Ok(result)
133}
134
135/// Estimate the number of tokens in a given text using the GPT-2 tokenizer.
136/// Requires the `token_estimation` feature. Returns 0 if the feature is disabled.
137pub fn estimate_tokens(_text: &str) -> Result<usize, AgentError> {
138    #[cfg(feature = "token_estimation")]
139    {
140        Ok(tokenizer()
141            .as_ref()
142            .map_err(|e| AgentError::TokenizerLoadingError(e.to_string()))?
143            .count_tokens(_text))
144    }
145    #[cfg(not(feature = "token_estimation"))]
146    {
147        Ok(0)
148    }
149}
150
#[cfg(test)]
mod tests {
    use super::*;
    use chrono::Utc;
    use microagents_events::{
        AssistantResponseEvent, SessionInitEvent, SessionInitType, SessionStopEvent,
        SkillLoadEvent, StreamDeltaEvent, ToolCallEvent, ToolResultEvent, Usage,
        UserPromptSubmitEvent,
        types::{FunctionCall as EventFunctionCall, ToolCall as EventToolCall},
    };

    #[test]
    fn test_convert_user_prompt_submit() {
        let event = AgentEventAny::UserPromptSubmit(UserPromptSubmitEvent {
            session_id: "s1".into(),
            turn_id: "t1".into(),
            prompt: "hello".into(),
            timestamp: Utc::now(),
        });
        let msg = convert_event_to_message(event).unwrap();
        assert_eq!(msg.role, Role::User);
        assert_eq!(msg.content, "hello");
        assert!(msg.tool_calls.is_none());
        assert!(msg.tool_call_id.is_none());
    }

    #[test]
    fn test_convert_assistant_response_without_tool_calls() {
        let event = AgentEventAny::AssistantResponse(AssistantResponseEvent {
            session_id: "s1".into(),
            turn_id: "t1".into(),
            full_text: "hi there".into(),
            tool_calls: None,
            timestamp: Utc::now(),
        });
        let msg = convert_event_to_message(event).unwrap();
        assert_eq!(msg.role, Role::Assistant);
        assert_eq!(msg.content, "hi there");
        assert!(msg.tool_calls.is_none());
        assert!(msg.tool_call_id.is_none());
    }

    #[test]
    fn test_convert_assistant_response_with_tool_calls() {
        let event = AgentEventAny::AssistantResponse(AssistantResponseEvent {
            session_id: "s1".into(),
            turn_id: "t1".into(),
            full_text: "calling tool".into(),
            tool_calls: Some(vec![EventToolCall {
                id: "tc1".into(),
                call_type: "function".into(),
                function: EventFunctionCall {
                    name: "my_tool".into(),
                    arguments: "{\"x\":1}".into(),
                },
            }]),
            timestamp: Utc::now(),
        });
        let msg = convert_event_to_message(event).unwrap();
        assert_eq!(msg.role, Role::Assistant);
        // The assistant text must survive alongside the tool calls.
        assert_eq!(msg.content, "calling tool");
        assert!(msg.tool_call_id.is_none());
        let calls = msg.tool_calls.unwrap();
        assert_eq!(calls.len(), 1);
        assert_eq!(calls[0].id, "tc1");
        assert_eq!(calls[0].function.name, "my_tool");
        assert_eq!(calls[0].function.arguments, "{\"x\":1}");
    }

    #[test]
    fn test_convert_assistant_response_preserves_tool_call_order() {
        let event = AgentEventAny::AssistantResponse(AssistantResponseEvent {
            session_id: "s1".into(),
            turn_id: "t1".into(),
            full_text: "two calls".into(),
            tool_calls: Some(vec![
                EventToolCall {
                    id: "first".into(),
                    call_type: "function".into(),
                    function: EventFunctionCall {
                        name: "tool_a".into(),
                        arguments: "{}".into(),
                    },
                },
                EventToolCall {
                    id: "second".into(),
                    call_type: "function".into(),
                    function: EventFunctionCall {
                        name: "tool_b".into(),
                        arguments: "{\"y\":2}".into(),
                    },
                },
            ]),
            timestamp: Utc::now(),
        });
        let calls = convert_event_to_message(event).unwrap().tool_calls.unwrap();
        assert_eq!(calls.len(), 2);
        assert_eq!(calls[0].id, "first");
        assert_eq!(calls[0].function.name, "tool_a");
        assert_eq!(calls[1].id, "second");
        assert_eq!(calls[1].function.name, "tool_b");
        assert_eq!(calls[1].function.arguments, "{\"y\":2}");
    }

    #[test]
    fn test_convert_tool_result_ok() {
        let event = AgentEventAny::ToolResult(ToolResultEvent {
            session_id: "s1".into(),
            turn_id: "t1".into(),
            result: ToolResult::Ok("done".into()),
            tool_call_id: "tc1".into(),
            timestamp: Utc::now(),
        });
        let msg = convert_event_to_message(event).unwrap();
        assert_eq!(msg.role, Role::Tool);
        assert_eq!(msg.content, "Tool call succeeded: done");
        assert_eq!(msg.tool_call_id, Some("tc1".into()));
        assert!(msg.tool_calls.is_none());
    }

    #[test]
    fn test_convert_tool_result_err() {
        let event = AgentEventAny::ToolResult(ToolResultEvent {
            session_id: "s1".into(),
            turn_id: "t1".into(),
            result: ToolResult::Err("oops".into()),
            tool_call_id: "tc2".into(),
            timestamp: Utc::now(),
        });
        let msg = convert_event_to_message(event).unwrap();
        assert_eq!(msg.role, Role::Tool);
        assert_eq!(msg.content, "Tool call failed: oops");
        assert_eq!(msg.tool_call_id, Some("tc2".into()));
        assert!(msg.tool_calls.is_none());
    }

    #[test]
    fn test_convert_other_events_return_none() {
        assert!(
            convert_event_to_message(AgentEventAny::SessionInit(SessionInitEvent {
                session_id: "s1".into(),
                model: "m".into(),
                provider: "p".into(),
                system: "sys".into(),
                init_type: SessionInitType::Start,
                timestamp: Utc::now(),
            }))
            .is_none()
        );

        assert!(
            convert_event_to_message(AgentEventAny::SessionStop(SessionStopEvent {
                session_id: "s1".into(),
                success: true,
                result: None,
                error: None,
                timestamp: Utc::now(),
                usage: Usage::default()
            }))
            .is_none()
        );

        assert!(
            convert_event_to_message(AgentEventAny::StreamDelta(StreamDeltaEvent {
                session_id: "s1".into(),
                turn_id: "t1".into(),
                delta: "d".into(),
                delta_type: microagents_events::DeltaType::Text,
                timestamp: Utc::now(),
            }))
            .is_none()
        );

        assert!(
            convert_event_to_message(AgentEventAny::ToolCall(ToolCallEvent {
                session_id: "s1".into(),
                turn_id: "t1".into(),
                name: "tool".into(),
                input: Value::Null,
                timestamp: Utc::now(),
            }))
            .is_none()
        );

        assert!(
            convert_event_to_message(AgentEventAny::SkillLoad(SkillLoadEvent {
                session_id: "s1".into(),
                turn_id: "t1".into(),
                skill_name: "skill".into(),
                timestamp: Utc::now(),
            }))
            .is_none()
        );
    }

    #[test]
    fn test_parse_json_fragment_valid() {
        match parse_json_fragment(r#"{"key": "value"}"#) {
            JsonResult::Valid(v) => assert_eq!(v["key"], "value"),
            _ => panic!("expected Valid"),
        }
    }

    #[test]
    fn test_parse_json_fragment_incomplete() {
        match parse_json_fragment(r#"{"key": "val""#) {
            JsonResult::Incomplete => {}
            _ => panic!("expected Incomplete"),
        }
    }

    #[test]
    fn test_parse_json_fragment_empty_input_is_incomplete() {
        // Nothing has streamed in yet: serde_json reports EOF, so keep buffering.
        assert!(matches!(parse_json_fragment(""), JsonResult::Incomplete));
        assert!(matches!(parse_json_fragment("   "), JsonResult::Incomplete));
    }

    #[test]
    fn test_parse_json_fragment_truncated_literal_and_number_are_incomplete() {
        assert!(matches!(
            parse_json_fragment(r#"{"ok": tru"#),
            JsonResult::Incomplete
        ));
        assert!(matches!(
            parse_json_fragment(r#"{"n": 12"#),
            JsonResult::Incomplete
        ));
    }

    #[test]
    fn test_parse_json_fragment_malformed() {
        match parse_json_fragment(r#"{"key": "value",}"#) {
            JsonResult::Malformed => {}
            _ => panic!("expected Malformed"),
        }
    }

    #[test]
    fn test_parse_json_fragment_trailing_garbage_is_malformed() {
        assert!(matches!(
            parse_json_fragment(r#"{"a": 1} x"#),
            JsonResult::Malformed
        ));
    }

    #[derive(Debug)]
    struct DummyTool {
        schema: Value,
    }

    #[async_trait::async_trait]
    impl ToolFunction<()> for DummyTool {
        fn name(&self) -> &'static str {
            "dummy"
        }
        fn description(&self) -> &'static str {
            "desc"
        }
        fn input_schema(&self) -> Value {
            self.schema.clone()
        }
        async fn execute(
            &self,
            _input: Value,
            _ctx: &Arc<ToolExecutionContext<()>>,
        ) -> Result<ToolResult, AgentError> {
            Ok(ToolResult::Ok("ok".into()))
        }
    }

    #[tokio::test]
    async fn test_call_tool_validates_and_executes() {
        let schema = serde_json::json!({
            "type": "object",
            "properties": {
                "name": { "type": "string" }
            },
            "required": ["name"]
        });
        let tool = Arc::new(DummyTool { schema });
        let ctx = Arc::new(ToolExecutionContext::new(()));
        let args = serde_json::json!({"name": "world"});
        let result = call_tool(tool, args, ctx).await.unwrap();
        assert!(matches!(result, ToolResult::Ok(ref s) if s == "ok"));
    }

    #[tokio::test]
    async fn test_call_tool_schema_validation_fails() {
        let schema = serde_json::json!({
            "type": "object",
            "properties": {
                "count": { "type": "integer" }
            },
            "required": ["count"]
        });
        let tool = Arc::new(DummyTool { schema });
        let ctx = Arc::new(ToolExecutionContext::new(()));
        let args = serde_json::json!({"count": "not a number"});
        let err = call_tool(tool, args, ctx).await.unwrap_err();
        match err {
            AgentError::ToolCallError(_) => {}
            other => panic!("expected ToolCallError, got {:?}", other),
        }
    }

    #[test]
    #[cfg(feature = "token_estimation")]
    fn test_estimate_tokens() {
        let count = estimate_tokens("hello world").expect("Should be able to estimate tokens");
        assert_eq!(count, 2);
    }

    #[test]
    #[cfg(not(feature = "token_estimation"))]
    fn test_estimate_tokens() {
        let count = estimate_tokens("hello world").expect("Should be able to estimate tokens");
        assert_eq!(count, 0);
    }
}