use std::time::Duration;
use super::*;
use futures::future::join_all;
use siumai_core::error::LlmError as SiumaiLlmError;
use siumai_core::streaming::ChatStream;
use siumai_core::types::{ChatMessage as SiumaiChatMessage, ChatResponse, EmbeddingResponse, Tool};
/// Chat stub for tests that must never reach the chat capability.
///
/// Both `ChatCapability` methods panic when invoked.
struct StubChat;

impl StubChat {
    /// Shared diverging path for every `ChatCapability` method of the stub.
    fn must_not_be_called() -> ! {
        unreachable!("StubChat should not be called in OnceCell dedup test")
    }
}

#[async_trait::async_trait]
impl ChatCapability for StubChat {
    /// Panics: the stub is not expected to receive chat requests.
    async fn chat_with_tools(
        &self,
        _messages: Vec<SiumaiChatMessage>,
        _tools: Option<Vec<Tool>>,
    ) -> Result<ChatResponse, SiumaiLlmError> {
        Self::must_not_be_called()
    }

    /// Panics: the stub is not expected to receive streaming chat requests.
    async fn chat_stream(
        &self,
        _messages: Vec<SiumaiChatMessage>,
        _tools: Option<Vec<Tool>>,
    ) -> Result<ChatStream, SiumaiLlmError> {
        Self::must_not_be_called()
    }
}
/// Embedding stub that answers every request with one fixed vector.
struct StubEmbed;

impl StubEmbed {
    /// Dimension reported by the stub; equals the length of the fixed vector
    /// returned from `embed`.
    const DIMENSION: usize = 3;
}

#[async_trait::async_trait]
impl EmbeddingCapability for StubEmbed {
    /// Ignores the input and returns a single three-element embedding
    /// labelled with the model name `"stub-model"`.
    async fn embed(&self, _input: Vec<String>) -> Result<EmbeddingResponse, SiumaiLlmError> {
        let vectors = vec![vec![1.0, 2.0, 3.0]];
        let response = EmbeddingResponse::new(vectors, "stub-model".into());
        Ok(response)
    }

    /// Exposes the stub itself as its `EmbeddingExtensions` implementation.
    fn as_embedding_extensions(&self) -> Option<&dyn EmbeddingExtensions> {
        Some(self)
    }

    /// Reports the fixed dimension of the stub embedding.
    fn embedding_dimension(&self) -> usize {
        Self::DIMENSION
    }
}

// No methods are overridden here; the trait's own defaults are used.
#[async_trait::async_trait]
impl EmbeddingExtensions for StubEmbed {}
/// Checks that `map_tool_definition` carries the name, description, and
/// JSON-schema parameters of a `ToolDefinition` over to the function tool.
#[test]
fn map_tool_definition_function_name_and_description() {
    // One schema value serves as both the input parameters and the expectation.
    let schema = serde_json::json!({
        "type": "object",
        "properties": {
            "location": { "type": "string" }
        }
    });
    let definition = ToolDefinition {
        name: "get_weather".into(),
        description: "Get weather for a city".into(),
        parameters: schema.as_object().unwrap().clone(),
    };

    let mapped = map_tool_definition(definition);
    let function = mapped.function_ref().expect("expected function tool");

    assert_eq!(function.name, "get_weather");
    assert_eq!(function.description, "Get weather for a city");
    assert_eq!(function.parameters, schema);
}
/// `ToolChoice::Auto` maps to `SiumaiToolChoice::Auto`.
#[test]
fn map_tool_choice_auto() {
    assert_eq!(map_tool_choice(ToolChoice::Auto), SiumaiToolChoice::Auto);
}
/// `ToolChoice::None` maps to `SiumaiToolChoice::None`.
#[test]
fn map_tool_choice_none() {
    assert_eq!(map_tool_choice(ToolChoice::None), SiumaiToolChoice::None);
}
/// A `ToolChoice::Specific` name maps to the equivalent
/// `SiumaiToolChoice::tool` value for the same name.
#[test]
fn map_tool_choice_specific() {
    const TOOL: &str = "get_weather";

    let mapped = map_tool_choice(ToolChoice::Specific(TOOL.into()));
    let expected = SiumaiToolChoice::tool(TOOL);

    assert_eq!(mapped, expected);
}
/// A complete `ChatStreamPart::ToolCall` converts into a `ChatEvent::ToolCall`
/// carrying the same id, tool name, and raw argument string.
#[test]
fn convert_stream_part_tool_call() {
    use siumai_core::types::{ChatStreamPart, ChatStreamToolCall};

    let call = ChatStreamToolCall {
        tool_call_id: "call_123".into(),
        tool_name: "get_weather".into(),
        input: r#"{"city":"London"}"#.into(),
        provider_executed: None,
        dynamic: None,
        provider_metadata: None,
    };

    let converted = convert_stream_part(ChatStreamPart::ToolCall(call)).expect("expected Some");

    match converted {
        ChatEvent::ToolCall {
            id,
            name,
            arguments,
        } => {
            assert_eq!(id, "call_123");
            assert_eq!(name, "get_weather");
            assert_eq!(arguments, r#"{"city":"London"}"#);
        }
        _ => panic!("expected ToolCall"),
    }
}
/// Walks one tool call through start -> two deltas -> end and checks that
/// `extract_tool_call_event` emits nothing until the end event, then yields a
/// `ChatEvent::ToolCall` with the concatenated arguments and drops the buffer.
#[test]
fn extract_tool_call_delta_accumulation() {
    use siumai_core::streaming::ChatStreamEvent;
    use siumai_core::types::ChatStreamPart;

    const CALL_ID: &str = "call_1";

    let mut buffers: HashMap<String, (String, String)> = HashMap::new();

    // Start: nothing is emitted, but a buffer entry appears for the call id.
    let start_event = ChatStreamEvent::Part {
        part: ChatStreamPart::ToolInputStart {
            id: CALL_ID.into(),
            tool_name: "get_weather".into(),
            provider_metadata: None,
            provider_executed: None,
            dynamic: None,
            title: None,
        },
    };
    let on_start = extract_tool_call_event(&start_event, &mut buffers);
    assert!(on_start.is_none(), "ToolInputStart should not emit ToolCall");
    assert!(
        buffers.contains_key(CALL_ID),
        "buffer should be initialised"
    );

    // Deltas: the argument JSON arrives in two fragments, neither emits.
    for fragment in [r#"{"city":"L"#, "ondon\"}"] {
        let delta_event = ChatStreamEvent::Part {
            part: ChatStreamPart::ToolInputDelta {
                id: CALL_ID.into(),
                delta: fragment.into(),
                provider_metadata: None,
            },
        };
        assert!(extract_tool_call_event(&delta_event, &mut buffers).is_none());
    }

    // End: the assembled call is emitted.
    let end_event = ChatStreamEvent::Part {
        part: ChatStreamPart::ToolInputEnd {
            id: CALL_ID.into(),
            provider_metadata: None,
        },
    };
    let emitted =
        extract_tool_call_event(&end_event, &mut buffers).expect("end should emit ToolCall");
    match emitted {
        ChatEvent::ToolCall {
            id,
            name,
            arguments,
        } => {
            assert_eq!(id, CALL_ID);
            assert_eq!(name, "get_weather");
            assert_eq!(arguments, r#"{"city":"London"}"#);
        }
        _ => panic!("expected ToolCall"),
    }

    // The buffer entry must be gone once the call has been emitted.
    assert!(!buffers.contains_key(CALL_ID));
}
/// A `ChatRole::Tool` message converts to a Tool-role message with exactly one
/// tool result that keeps the call id and whose output contains the content.
#[test]
fn convert_chat_message_tool_role() {
    use siumai_core::types::MessageRole;

    let tool_message = ChatMessage {
        role: ChatRole::Tool {
            tool_call_id: "call_99".into(),
        },
        content: "weather: sunny, 22°C".into(),
        tool_calls: None,
    };

    let siumai = convert_chat_message(tool_message);
    assert_eq!(siumai.role, MessageRole::Tool);

    let results = siumai.tool_results();
    assert_eq!(results.len(), 1, "expected one tool result");

    let info = results[0].as_tool_result().expect("tool result info");
    assert_eq!(info.tool_call_id, "call_99");

    // The output is inspected through its Debug rendering, so only substrings
    // of the original content are checked.
    let output_dbg = format!("{:#?}", info.output);
    for needle in ["sunny", "22"] {
        assert!(
            output_dbg.contains(needle),
            "tool result output: {output_dbg}"
        );
    }
}
/// An assistant message with one emitted tool call converts to an
/// Assistant-role message that exposes that call (id, name, parsed JSON
/// arguments) and still keeps its text content.
#[test]
fn convert_chat_message_assistant_forwards_tool_calls() {
    use siumai_core::types::MessageRole;

    let emitted = EmittedToolCall {
        id: "call_42".into(),
        name: "get_weather".into(),
        arguments: r#"{"city":"London"}"#.into(),
    };
    let assistant_message = ChatMessage {
        role: ChatRole::Assistant,
        content: "Let me check the weather.".into(),
        tool_calls: Some(vec![emitted]),
    };

    let siumai = convert_chat_message(assistant_message);
    assert_eq!(siumai.role, MessageRole::Assistant);

    let forwarded = siumai.tool_calls();
    assert_eq!(forwarded.len(), 1, "expected one tool call");

    let info = forwarded[0].as_tool_call().expect("tool call info");
    assert_eq!(info.tool_call_id, "call_42");
    assert_eq!(info.tool_name, "get_weather");
    // The argument string is expected to surface as a parsed JSON value.
    let expected_arguments = serde_json::json!({"city":"London"});
    assert_eq!(info.arguments, &expected_arguments);

    assert_eq!(siumai.content_text(), Some("Let me check the weather."));
}
/// An assistant message with `tool_calls: None` converts to a text-only
/// Assistant-role message.
#[test]
fn convert_chat_message_assistant_no_tool_calls_stays_text() {
    use siumai_core::types::MessageRole;

    let plain = ChatMessage {
        role: ChatRole::Assistant,
        content: "Just a plain response.".into(),
        tool_calls: None,
    };

    let siumai = convert_chat_message(plain);

    assert_eq!(siumai.role, MessageRole::Assistant);
    assert!(siumai.tool_calls().is_empty(), "no tool calls expected");
    assert_eq!(siumai.content_text(), Some("Just a plain response."));
}
/// An assistant message with an empty `tool_calls` vector converts to a
/// text-only Assistant-role message.
#[test]
fn convert_chat_message_assistant_empty_tool_calls_stays_text() {
    use siumai_core::types::MessageRole;

    let with_empty_calls = ChatMessage {
        role: ChatRole::Assistant,
        content: "Empty array.".into(),
        tool_calls: Some(Vec::new()),
    };

    let siumai = convert_chat_message(with_empty_calls);

    assert_eq!(siumai.role, MessageRole::Assistant);
    assert!(siumai.tool_calls().is_empty(), "no tool calls expected");
    assert_eq!(siumai.content_text(), Some("Empty array."));
}
/// `convert_chat_request` forwards both the tool list and a specific tool
/// choice from a `ChatRequest` into the converted request.
#[test]
fn chat_request_with_tools_and_tool_choice() {
    let schema = serde_json::json!({
        "type": "object",
        "properties": {
            "city": { "type": "string" }
        }
    });
    let weather_tool = ToolDefinition {
        name: "get_weather".into(),
        description: "Get weather".into(),
        parameters: schema.as_object().unwrap().clone(),
    };
    let request = ChatRequest {
        model: "gpt-4o".into(),
        messages: vec![ChatMessage::user("weather?")],
        temperature: None,
        max_tokens: None,
        stop: None,
        system_prompt: None,
        extra: serde_json::Map::new(),
        tools: vec![weather_tool],
        tool_choice: Some(ToolChoice::Specific("get_weather".into())),
    };

    let converted = convert_chat_request(request, "gpt-4o");

    // Exactly one tool, still named after the original definition.
    let forwarded_tools = converted.tools.expect("tools should be Some");
    assert_eq!(forwarded_tools.len(), 1);
    assert_eq!(
        forwarded_tools[0].function_ref().unwrap().name,
        "get_weather"
    );

    // The specific choice survives the conversion.
    let forwarded_choice = converted.tool_choice.expect("tool_choice should be Some");
    assert_eq!(forwarded_choice, SiumaiToolChoice::tool("get_weather"));
}
/// Issues eight `embed()` calls against one shared provider via `join_all`
/// and checks that the provider's `client_builds` counter ends at exactly 1.
#[tokio::test]
async fn client_built_once_under_concurrency() {
    const CONCURRENT_CALLS: usize = 8;

    // Build override that hands back the in-file stubs.
    let stub_builder: BuildOverride = Arc::new(|| {
        Box::pin(async {
            let chat = Arc::new(StubChat) as Arc<dyn ChatCapability>;
            let embed = Arc::new(StubEmbed) as Arc<dyn EmbeddingCapability>;
            Ok::<_, LlmError>((chat, embed))
        })
    });

    let config = SiumaiConfig::OpenAi {
        api_key: "sk-test".into(),
        base_url: Some("http://127.0.0.1:1".into()),
        model: "gpt-4".into(),
    };
    let provider = Arc::new(SiumaiProvider::new_with_build_override(
        "test",
        "gpt-4",
        config,
        Duration::from_secs(30),
        stub_builder,
    ));
    let build_counter = Arc::clone(&provider.client_builds);

    // Create every future up front, then drive them together. The individual
    // embed results are deliberately ignored; only the build count matters.
    let pending: Vec<_> = (0..CONCURRENT_CALLS)
        .map(|_| {
            let shared = Arc::clone(&provider);
            async move {
                let _ = shared
                    .embed(EmbedRequest::new("gpt-4", vec!["hello".into()]))
                    .await;
            }
        })
        .collect();
    join_all(pending).await;

    assert_eq!(
        build_counter.load(Ordering::SeqCst),
        1,
        "OnceCell must construct the client exactly once across concurrent embed() calls"
    );
}
/// A siumai `TimeoutError` maps to `LlmError::Timeout` carrying the duration
/// that was passed to `map_siumai_error`.
#[test]
fn timeout_maps_to_llm_error_timeout_not_network() {
    let configured = Duration::from_secs(10);
    let source = SiumaiLlmError::TimeoutError("elapsed".into());

    let mapped = map_siumai_error(source, configured);

    assert!(matches!(mapped, LlmError::Timeout(d) if d == configured));
}
/// A siumai `ConnectionError` maps to `LlmError::Network`.
#[test]
fn other_errors_still_map_correctly() {
    let source = SiumaiLlmError::ConnectionError("conn reset".into());

    let mapped = map_siumai_error(source, Duration::from_secs(30));

    assert!(matches!(mapped, LlmError::Network(_)));
}