1use std::sync::Arc;
2#[cfg(feature = "token_estimation")]
3use std::sync::OnceLock;
4
5use microagents_events::{AgentEventAny, types::ToolResult};
6use serde_json::Value;
7use ultrafast_models_sdk::{
8 Message, Role,
9 models::{FunctionCall, ToolCall},
10};
11
12use crate::types::{AgentError, ToolExecutionContext, ToolFunction};
13
/// Process-wide cache for the GPT-2 tokenizer used by token estimation.
///
/// Whatever the first load attempt produces — the tokenizer or the load error — is stored
/// here and handed back on every later call, so a failed load is not retried.
#[cfg(feature = "token_estimation")]
static TOKENIZER: OnceLock<Result<tokie::Tokenizer, tokie::HubError>> = OnceLock::new();

/// Returns the shared tokenizer, loading the `"gpt2"` tokenizer on first use.
///
/// NOTE(review): the `HubError` type suggests the first load may fetch from a model hub
/// (network / cache access) — confirm against the `tokie` docs.
#[cfg(feature = "token_estimation")]
fn tokenizer() -> &'static Result<tokie::Tokenizer, tokie::HubError> {
    fn load_gpt2() -> Result<tokie::Tokenizer, tokie::HubError> {
        tokie::Tokenizer::from_pretrained("gpt2")
    }

    TOKENIZER.get_or_init(load_gpt2)
}
21
/// Checks that the environment variable whose *name* is `api_key` is set.
///
/// `api_key` is the variable name (for example `"OPENAI_API_KEY"`), not the secret itself.
/// The variable's value is read and then discarded. A variable set to the empty string is
/// accepted.
///
/// # Errors
///
/// Returns the [`std::env::VarError`] produced by [`std::env::var`]: `NotPresent` when the
/// variable is unset, or `NotUnicode` when its value is not valid Unicode.
pub fn check_api_key(api_key: &str) -> Result<(), std::env::VarError> {
    std::env::var(api_key).map(|_value| ())
}
29
30pub fn convert_event_to_message(event: AgentEventAny) -> Option<Message> {
35 match event {
36 AgentEventAny::UserPromptSubmit(p) => Some(Message {
37 role: Role::User,
38 content: p.prompt,
39 name: None,
40 tool_calls: None,
41 tool_call_id: None,
42 }),
43 AgentEventAny::AssistantResponse(p) => {
44 let msg = if let Some(tc) = p.tool_calls {
45 let calls: Vec<ToolCall> = tc
46 .iter()
47 .map(|t| ToolCall {
48 call_type: t.call_type.clone(),
49 id: t.id.clone(),
50 function: FunctionCall {
51 name: t.function.name.clone(),
52 arguments: t.function.arguments.clone(),
53 },
54 })
55 .collect();
56 Message {
57 role: Role::Assistant,
58 content: p.full_text,
59 name: None,
60 tool_calls: Some(calls),
61 tool_call_id: None,
62 }
63 } else {
64 Message {
65 role: Role::Assistant,
66 content: p.full_text,
67 name: None,
68 tool_calls: None,
69 tool_call_id: None,
70 }
71 };
72 Some(msg)
73 }
74 AgentEventAny::ToolResult(p) => {
75 let result = match p.result {
76 ToolResult::Ok(r) => format!("Tool call succeeded: {}", r),
77 ToolResult::Err(r) => format!("Tool call failed: {}", r),
78 _ => unreachable!("ToolResult should not reach this branch"),
79 };
80 Some(Message {
81 role: Role::Tool,
82 content: result,
83 name: None,
84 tool_calls: None,
85 tool_call_id: Some(p.tool_call_id),
86 })
87 }
88 _ => None,
89 }
90}
91
92pub enum JsonResult {
94 Valid(Value),
96 Incomplete,
98 Malformed,
100}
101
102pub fn parse_json_fragment(s: &str) -> JsonResult {
107 let v = serde_json::from_str::<Value>(s);
108 match v {
109 Ok(val) => JsonResult::Valid(val),
110 Err(e) => {
111 if e.is_eof() {
112 return JsonResult::Incomplete;
113 }
114 JsonResult::Malformed
115 }
116 }
117}
118
119pub async fn call_tool<Ctx: Send + Sync + 'static>(
125 tool: Arc<dyn ToolFunction<Ctx>>,
126 tool_args: Value,
127 tool_context: Arc<ToolExecutionContext<Ctx>>,
128) -> Result<ToolResult, AgentError> {
129 jsonschema::validate(&tool.input_schema(), &tool_args)
130 .map_err(|e| AgentError::ToolCallError(e.to_string()))?;
131 let result = tool.execute(tool_args, &tool_context).await?;
132 Ok(result)
133}
134
135pub fn estimate_tokens(_text: &str) -> Result<usize, AgentError> {
138 #[cfg(feature = "token_estimation")]
139 {
140 Ok(tokenizer()
141 .as_ref()
142 .map_err(|e| AgentError::TokenizerLoadingError(e.to_string()))?
143 .count_tokens(_text))
144 }
145 #[cfg(not(feature = "token_estimation"))]
146 {
147 Ok(0)
148 }
149}
150
#[cfg(test)]
mod tests {
    use super::*;
    use chrono::Utc;
    use microagents_events::{
        AssistantResponseEvent, SessionInitEvent, SessionInitType, SessionStopEvent,
        SkillLoadEvent, StreamDeltaEvent, ToolCallEvent, ToolResultEvent, Usage,
        UserPromptSubmitEvent,
        types::{FunctionCall as EventFunctionCall, ToolCall as EventToolCall},
    };

    /// Builds an assistant-response event for session "s1", turn "t1".
    fn assistant_event(text: &str, tool_calls: Option<Vec<EventToolCall>>) -> AgentEventAny {
        AgentEventAny::AssistantResponse(AssistantResponseEvent {
            session_id: "s1".into(),
            turn_id: "t1".into(),
            full_text: text.into(),
            tool_calls,
            timestamp: Utc::now(),
        })
    }

    /// Builds a tool-result event for session "s1", turn "t1".
    fn tool_result_event(result: ToolResult, call_id: &str) -> AgentEventAny {
        AgentEventAny::ToolResult(ToolResultEvent {
            session_id: "s1".into(),
            turn_id: "t1".into(),
            result,
            tool_call_id: call_id.into(),
            timestamp: Utc::now(),
        })
    }

    #[test]
    fn test_convert_user_prompt_submit() {
        let submitted = AgentEventAny::UserPromptSubmit(UserPromptSubmitEvent {
            session_id: "s1".into(),
            turn_id: "t1".into(),
            prompt: "hello".into(),
            timestamp: Utc::now(),
        });
        let message = convert_event_to_message(submitted).expect("user prompts map to a message");
        assert_eq!(message.role, Role::User);
        assert_eq!(message.content, "hello");
        assert!(message.tool_calls.is_none());
        assert!(message.tool_call_id.is_none());
    }

    #[test]
    fn test_convert_assistant_response_without_tool_calls() {
        let message = convert_event_to_message(assistant_event("hi there", None))
            .expect("assistant responses map to a message");
        assert_eq!(message.role, Role::Assistant);
        assert_eq!(message.content, "hi there");
        assert!(message.tool_calls.is_none());
    }

    #[test]
    fn test_convert_assistant_response_with_tool_calls() {
        let requested = vec![EventToolCall {
            id: "tc1".into(),
            call_type: "function".into(),
            function: EventFunctionCall {
                name: "my_tool".into(),
                arguments: "{\"x\":1}".into(),
            },
        }];
        let message = convert_event_to_message(assistant_event("calling tool", Some(requested)))
            .expect("assistant responses map to a message");
        assert_eq!(message.role, Role::Assistant);

        let converted = message
            .tool_calls
            .expect("tool calls should be carried over");
        assert_eq!(converted.len(), 1);
        let first = &converted[0];
        assert_eq!(first.id, "tc1");
        assert_eq!(first.function.name, "my_tool");
        assert_eq!(first.function.arguments, "{\"x\":1}");
    }

    #[test]
    fn test_convert_tool_result_ok() {
        let message = convert_event_to_message(tool_result_event(ToolResult::Ok("done".into()), "tc1"))
            .expect("tool results map to a message");
        assert_eq!(message.role, Role::Tool);
        assert_eq!(message.content, "Tool call succeeded: done");
        assert_eq!(message.tool_call_id, Some("tc1".into()));
    }

    #[test]
    fn test_convert_tool_result_err() {
        let message = convert_event_to_message(tool_result_event(ToolResult::Err("oops".into()), "tc2"))
            .expect("tool results map to a message");
        assert_eq!(message.role, Role::Tool);
        assert_eq!(message.content, "Tool call failed: oops");
        assert_eq!(message.tool_call_id, Some("tc2".into()));
    }

    #[test]
    fn test_convert_other_events_return_none() {
        let non_conversational = vec![
            AgentEventAny::SessionInit(SessionInitEvent {
                session_id: "s1".into(),
                model: "m".into(),
                provider: "p".into(),
                system: "sys".into(),
                init_type: SessionInitType::Start,
                timestamp: Utc::now(),
            }),
            AgentEventAny::SessionStop(SessionStopEvent {
                session_id: "s1".into(),
                success: true,
                result: None,
                error: None,
                timestamp: Utc::now(),
                usage: Usage::default(),
            }),
            AgentEventAny::StreamDelta(StreamDeltaEvent {
                session_id: "s1".into(),
                turn_id: "t1".into(),
                delta: "d".into(),
                delta_type: microagents_events::DeltaType::Text,
                timestamp: Utc::now(),
            }),
            AgentEventAny::ToolCall(ToolCallEvent {
                session_id: "s1".into(),
                turn_id: "t1".into(),
                name: "tool".into(),
                input: Value::Null,
                timestamp: Utc::now(),
            }),
            AgentEventAny::SkillLoad(SkillLoadEvent {
                session_id: "s1".into(),
                turn_id: "t1".into(),
                skill_name: "skill".into(),
                timestamp: Utc::now(),
            }),
        ];

        for (index, event) in non_conversational.into_iter().enumerate() {
            assert!(
                convert_event_to_message(event).is_none(),
                "event #{index} unexpectedly produced a message"
            );
        }
    }

    #[test]
    fn test_parse_json_fragment_valid() {
        let JsonResult::Valid(parsed) = parse_json_fragment(r#"{"key": "value"}"#) else {
            panic!("expected Valid");
        };
        assert_eq!(parsed["key"], "value");
    }

    #[test]
    fn test_parse_json_fragment_incomplete() {
        assert!(
            matches!(parse_json_fragment(r#"{"key": "val""#), JsonResult::Incomplete),
            "expected Incomplete"
        );
    }

    #[test]
    fn test_parse_json_fragment_malformed() {
        assert!(
            matches!(parse_json_fragment(r#"{"key": "value",}"#), JsonResult::Malformed),
            "expected Malformed"
        );
    }

    /// Minimal tool whose input schema is configurable and which always succeeds with "ok".
    #[derive(Debug)]
    struct DummyTool {
        schema: Value,
    }

    #[async_trait::async_trait]
    impl ToolFunction<()> for DummyTool {
        fn name(&self) -> &'static str {
            "dummy"
        }
        fn description(&self) -> &'static str {
            "desc"
        }
        fn input_schema(&self) -> Value {
            self.schema.clone()
        }
        async fn execute(
            &self,
            _input: Value,
            _ctx: &Arc<ToolExecutionContext<()>>,
        ) -> Result<ToolResult, AgentError> {
            Ok(ToolResult::Ok("ok".into()))
        }
    }

    /// Runs `call_tool` on a `DummyTool` with the given schema and arguments.
    async fn run_dummy(schema: Value, args: Value) -> Result<ToolResult, AgentError> {
        let tool = Arc::new(DummyTool { schema });
        let ctx = Arc::new(ToolExecutionContext::new(()));
        call_tool(tool, args, ctx).await
    }

    #[tokio::test]
    async fn test_call_tool_validates_and_executes() {
        let schema = serde_json::json!({
            "type": "object",
            "properties": {
                "name": { "type": "string" }
            },
            "required": ["name"]
        });
        let outcome = run_dummy(schema, serde_json::json!({"name": "world"}))
            .await
            .expect("conforming arguments should execute");
        assert!(matches!(outcome, ToolResult::Ok(ref s) if s == "ok"));
    }

    #[tokio::test]
    async fn test_call_tool_schema_validation_fails() {
        let schema = serde_json::json!({
            "type": "object",
            "properties": {
                "count": { "type": "integer" }
            },
            "required": ["count"]
        });
        let err = run_dummy(schema, serde_json::json!({"count": "not a number"}))
            .await
            .unwrap_err();
        assert!(
            matches!(err, AgentError::ToolCallError(_)),
            "expected ToolCallError, got {:?}",
            err
        );
    }

    // NOTE(review): with the feature enabled, the first tokenizer load may need network or
    // hub-cache access.
    #[test]
    #[cfg(feature = "token_estimation")]
    fn test_estimate_tokens() {
        let tokens = estimate_tokens("hello world").expect("Should be able to estimate tokens");
        assert_eq!(tokens, 2);
    }

    #[test]
    #[cfg(not(feature = "token_estimation"))]
    fn test_estimate_tokens() {
        let tokens = estimate_tokens("hello world").expect("Should be able to estimate tokens");
        assert_eq!(tokens, 0);
    }
}