use super::config::MistralV3Config;
use super::model::{
MistralV3Attention, MistralV3ForCausalLM, MistralV3MLP, MistralV3Model, MistralV3RmsNorm,
};
use super::tools::{
MistralV3FunctionCaller, ToolCall, ToolDefinition, ToolParameter, ToolParseError,
};
use std::collections::HashMap;
use trustformers_core::traits::Config;
use trustformers_core::{tensor::Tensor, traits::Layer};
/// The Mistral 7B v0.3 preset must report a vocabulary of 32768 entries.
#[test]
fn test_mistral_v3_config_vocab_size() {
    let preset = MistralV3Config::mistral_7b_v0_3();
    assert_eq!(preset.vocab_size, 32768);
}
/// The Mistral 7B v0.3 preset must pair 8 key/value heads with 32 attention heads.
#[test]
fn test_mistral_v3_config_kv_heads() {
    let preset = MistralV3Config::mistral_7b_v0_3();
    // Key/value head count is checked first, then the attention head count.
    assert_eq!(preset.num_key_value_heads, 8);
    assert_eq!(preset.num_attention_heads, 32);
}
/// The Mistral 7B v0.3 preset must configure a sliding window of 4096.
#[test]
fn test_mistral_v3_config_sliding_window() {
    let preset = MistralV3Config::mistral_7b_v0_3();
    assert_eq!(preset.sliding_window, 4096);
}
/// The Mistral 7B v0.3 preset must set `rope_theta` to 1_000_000 (within 1e-6).
#[test]
fn test_mistral_v3_config_rope_theta() {
    let preset = MistralV3Config::mistral_7b_v0_3();
    // Compare floats via absolute deviation rather than exact equality.
    let deviation = (preset.rope_theta - 1_000_000.0_f64).abs();
    assert!(deviation < 1e-6);
}
/// The `small_test` configuration must pass `validate()`.
#[test]
fn test_mistral_v3_config_validation_ok() {
    let small = MistralV3Config::small_test();
    let outcome = small.validate();
    assert!(outcome.is_ok());
}
/// Running an 8-element tensor through `MistralV3RmsNorm` must yield 8 elements.
#[test]
fn test_mistral_v3_rms_norm_output_shape() {
    let layer = MistralV3RmsNorm::new(8, 1e-5).expect("new");
    let sample = Tensor::from_vec(vec![0.5_f32; 8], &[8]).expect("tensor");
    let normalized = layer.forward(sample).expect("forward");
    // Only the total element count is checked, not the individual dimensions.
    let element_count: usize = normalized.shape().iter().product();
    assert_eq!(element_count, 8);
}
/// `repeat_kv` on a `head_dim`-sized tensor must produce
/// `head_dim * num_query_groups()` elements.
#[test]
fn test_mistral_v3_gqa_repeat_kv() {
    let config = MistralV3Config::small_test();
    let attention = MistralV3Attention::new(&config).expect("new");
    let head_dim = config.head_dim();
    let single_head = Tensor::from_vec(vec![1.0_f32; head_dim], &[head_dim]).expect("tensor");
    let repeated = attention.repeat_kv(&single_head).expect("repeat_kv");
    let actual_len: usize = repeated.shape().iter().product();
    let expected_len = head_dim * config.num_query_groups();
    assert_eq!(actual_len, expected_len);
}
/// For a sequence length of 4, `effective_window` must return 4.
#[test]
fn test_mistral_v3_sliding_window_effective_window_short_seq() {
    let config = MistralV3Config::small_test();
    let attention = MistralV3Attention::new(&config).expect("new");
    let window = attention.effective_window(4);
    assert_eq!(window, 4);
}
/// For a sequence length of 20, `effective_window` must return the configured
/// `sliding_window`.
// NOTE(review): presumably `small_test()` uses a window shorter than 20 — confirm.
#[test]
fn test_mistral_v3_sliding_window_effective_window_long_seq() {
    let config = MistralV3Config::small_test();
    let attention = MistralV3Attention::new(&config).expect("new");
    let window = attention.effective_window(20);
    assert_eq!(window, config.sliding_window);
}
/// Attention `forward` on a `[16, hidden_size]` input must preserve the total
/// number of elements.
#[test]
fn test_mistral_v3_attention_forward_long_seq() {
    let config = MistralV3Config::small_test();
    let attention = MistralV3Attention::new(&config).expect("new");

    let seq_len = 16_usize;
    let hidden = config.hidden_size;
    let values = vec![0.1_f32; seq_len * hidden];
    let input = Tensor::from_vec(values, &[seq_len, hidden]).expect("tensor");

    let output = attention.forward(input).expect("forward");
    let produced: usize = output.shape().iter().product();
    assert_eq!(produced, seq_len * config.hidden_size);
}
/// Running three token ids through `MistralV3Model::run` must yield
/// `3 * hidden_size` elements.
#[test]
fn test_mistral_v3_model_forward_small() {
    let config = MistralV3Config::small_test();
    // The model takes its own copy; `config` is still needed for the assertion.
    let model = MistralV3Model::new(config.clone()).expect("new");
    let token_ids = vec![1_u32, 2, 3];
    let hidden_states = model.run(token_ids).expect("run");
    let produced: usize = hidden_states.shape().iter().product();
    assert_eq!(produced, 3 * config.hidden_size);
}
/// Running three token ids through `MistralV3ForCausalLM::forward` must yield
/// `3 * vocab_size` elements.
#[test]
fn test_mistral_v3_causal_lm_forward_small() {
    let config = MistralV3Config::small_test();
    // The model takes its own copy; `config` is still needed for the assertion.
    let lm = MistralV3ForCausalLM::new(config.clone()).expect("new");
    let token_ids = vec![0_u32, 1, 2];
    let scores = lm.forward(token_ids).expect("forward");
    let produced: usize = scores.shape().iter().product();
    assert_eq!(produced, 3 * config.vocab_size);
}
/// MLP `forward` on a `[1, hidden_size]` input must yield `hidden_size` elements.
#[test]
fn test_mistral_v3_mlp_forward_shape() {
    let config = MistralV3Config::small_test();
    let feed_forward = MistralV3MLP::new(&config).expect("new");

    let row = vec![0.3_f32; config.hidden_size];
    let input = Tensor::from_vec(row, &[1, config.hidden_size]).expect("tensor");

    let output = feed_forward.forward(input).expect("forward");
    let produced: usize = output.shape().iter().product();
    assert_eq!(produced, config.hidden_size);
}
/// Builds the `get_weather` tool definition shared by the tool-calling tests.
///
/// The tool has two string parameters: a required `location` and an optional
/// `unit` restricted to `celsius` / `fahrenheit`.
fn make_weather_tool() -> ToolDefinition {
    let location = ToolParameter {
        param_type: "string".to_string(),
        description: "City name".to_string(),
        required: true,
        enum_values: None,
    };
    let unit = ToolParameter {
        param_type: "string".to_string(),
        description: "Temperature unit".to_string(),
        required: false,
        enum_values: Some(vec!["celsius".to_string(), "fahrenheit".to_string()]),
    };

    ToolDefinition {
        name: "get_weather".to_string(),
        description: "Get current weather for a location".to_string(),
        parameters: HashMap::from([
            ("location".to_string(), location),
            ("unit".to_string(), unit),
        ]),
    }
}
/// The weather tool fixture must be named `get_weather` and declare a
/// `location` parameter.
#[test]
fn test_tool_definition_creation() {
    let weather = make_weather_tool();
    assert_eq!(weather.name, "get_weather");
    let has_location = weather.parameters.contains_key("location");
    assert!(has_location);
}
/// The formatted tool prompt must mention the tool name and be wrapped in
/// `[AVAILABLE_TOOLS]` / `[/AVAILABLE_TOOLS]` markers.
#[test]
fn test_format_tool_prompt_contains_tool_name() {
    let config = MistralV3Config::small_test();
    let lm = MistralV3ForCausalLM::new(config).expect("new");
    let caller = MistralV3FunctionCaller::new(lm, vec![make_weather_tool()]);
    let prompt = caller.format_tool_prompt();

    // Each entry pairs a required substring with the failure message to report.
    let expectations = [
        ("get_weather", "prompt should contain tool name"),
        ("[AVAILABLE_TOOLS]", "prompt should contain [AVAILABLE_TOOLS]"),
        ("[/AVAILABLE_TOOLS]", "prompt should contain [/AVAILABLE_TOOLS]"),
    ];
    for (needle, message) in expectations {
        assert!(prompt.contains(needle), "{}", message);
    }
}
/// A `[TOOL_CALLS]` marker followed by a JSON array must parse into exactly one
/// `get_weather` call whose `location` argument is `"Paris"`.
#[test]
fn test_parse_tool_calls_valid() {
    let raw = r#"Sure! [TOOL_CALLS] [{"name": "get_weather", "arguments": {"location": "Paris"}}] More text."#;
    let parsed = MistralV3FunctionCaller::parse_tool_calls(raw).expect("parse");
    assert_eq!(parsed.len(), 1);

    // Index only after the length check so a wrong count fails with a clear assertion.
    let first = &parsed[0];
    assert_eq!(first.name, "get_weather");
    let location = first.arguments.get("location").and_then(|v| v.as_str());
    assert_eq!(location, Some("Paris"));
}
/// Text without any tool-call marker must yield `ToolParseError::NoToolCalls`.
#[test]
fn test_parse_tool_calls_no_marker() {
    let plain_text = "This response has no tool calls.";
    let outcome = MistralV3FunctionCaller::parse_tool_calls(plain_text);
    let is_no_tool_calls = matches!(outcome, Err(ToolParseError::NoToolCalls));
    assert!(is_no_tool_calls);
}
/// Validating a `get_weather` call with no arguments must fail with
/// `ToolParseError::MissingParam`.
#[test]
fn test_validate_tool_call_missing_required_param() {
    let config = MistralV3Config::small_test();
    let lm = MistralV3ForCausalLM::new(config).expect("new");
    let caller = MistralV3FunctionCaller::new(lm, vec![make_weather_tool()]);

    // Known tool name, but no arguments are supplied.
    let empty_call = ToolCall {
        name: "get_weather".to_string(),
        arguments: HashMap::new(),
    };

    let outcome = caller.validate_tool_call(&empty_call);
    let is_missing_param = matches!(outcome, Err(ToolParseError::MissingParam(_)));
    assert!(is_missing_param);
}
/// Validating a call to a tool that was never registered must fail with
/// `ToolParseError::UnknownTool`.
#[test]
fn test_validate_tool_call_unknown_tool() {
    let config = MistralV3Config::small_test();
    let lm = MistralV3ForCausalLM::new(config).expect("new");
    let caller = MistralV3FunctionCaller::new(lm, vec![make_weather_tool()]);

    // Only `get_weather` is registered with the caller above.
    let stray_call = ToolCall {
        name: "unknown_tool".to_string(),
        arguments: HashMap::new(),
    };

    let outcome = caller.validate_tool_call(&stray_call);
    let is_unknown_tool = matches!(outcome, Err(ToolParseError::UnknownTool(_)));
    assert!(is_unknown_tool);
}
/// The weather tool fixture must declare `location` as a required string and
/// `unit` as an optional parameter carrying enum values.
#[test]
fn test_tool_parameter_types() {
    let weather = make_weather_tool();

    let location = weather.parameters.get("location").expect("location param");
    assert_eq!(location.param_type, "string");
    assert!(location.required);

    let unit = weather.parameters.get("unit").expect("unit param");
    assert!(!unit.required);
    assert!(unit.enum_values.is_some());
}