1use std::collections::HashMap;
2
3use anyhow::Result;
4use either::Either;
5use indexmap::IndexMap;
6use itertools::Itertools;
7use minijinja::{context, value::Kwargs, Environment, Error, ErrorKind, Value};
8use regex::Regex;
9use serde::{Deserialize, Serialize};
10use tokenizers::Tokenizer;
11use tracing::trace;
12
13use crate::{MessageContent, ModelGenerationDefaults, Tool};
14
/// Additional end-of-sequence markers recognised by `calculate_eos_tokens`.
///
/// An entry is appended to the EOS set only when it is present in the tokenizer vocabulary
/// and appears in at least one chat template. Matches are appended in the order listed here.
const SUPPORTED_ALTERNATE_EOS: &[&str] = &[
    "<|im_end|>",
    "<end_of_turn>",
    "<|end_of_text|>",
    "<|end|>",
    "<|eot_id|>",
    "<|message|>",
    "<|start|>",
    "<|channel|>",
];

/// Value passed to the template as `enable_thinking` when the caller leaves it as `None`.
const DEFAULT_ENABLE_THINKING: bool = true;
28
/// An added/special token described as a JSON object rather than a bare string.
///
/// Used as the value type of `ChatTemplate::added_tokens_decoder` and as the object form of
/// `BeginEndUnkPadTok`. Field names are the JSON keys (no serde renames are applied). The
/// non-`Option` fields have no serde default, so a missing key fails deserialization.
// Only `content` is read in this file; the remaining fields are deserialized but unused,
// which is why `dead_code` is silenced.
#[allow(dead_code)]
#[derive(Debug, Deserialize)]
pub struct AddedTokensDecoder {
    // Matches a literal `"__type"` key.
    // NOTE(review): presumably emitted by some tokenizer configs — confirm.
    __type: Option<String>,
    /// Literal text of the token; this is what `eos_tok`/`bos_tok`/`unk_tok` return.
    pub content: String,
    lstrip: bool,
    normalized: bool,
    rstrip: bool,
    single_word: bool,
    special: Option<bool>,
}
40
41fn raise_exception(msg: String) -> Result<String, minijinja::Error> {
42 Err(minijinja::Error::new(ErrorKind::InvalidOperation, msg))
43}
44
/// A special token (BOS/EOS/UNK/PAD) as written in a tokenizer config.
///
/// The value is either a bare string (`Left`) or an [`AddedTokensDecoder`] object (`Right`);
/// the JSON carries no tag, so it is deserialized with `either::serde_untagged`.
#[derive(Debug, Deserialize)]
pub struct BeginEndUnkPadTok(
    /// String form or object form of the token.
    #[serde(with = "either::serde_untagged")] pub Either<String, AddedTokensDecoder>,
);
49
/// The `chat_template` entry of a tokenizer config.
///
/// Either a single Jinja template string (`Left`), or a list of string maps (`Right`). In the
/// list form, `apply_chat_template_to` accepts entries shaped like
/// `{"name": "default", "template": "..."}` as well as `{"default": "..."}` /
/// `{"tool_use": "..."}`.
#[derive(Debug, Deserialize)]
pub struct ChatTemplateValue(
    /// Single template, or list of named templates.
    #[serde(with = "either::serde_untagged")] pub Either<String, Vec<HashMap<String, String>>>,
);
54
/// Tokenizer configuration carrying the chat template and the special tokens.
///
/// Field names are the JSON keys (no serde renames are applied). Every field is an `Option`,
/// so absent keys deserialize to `None`.
///
/// NOTE(review): presumably loaded from a Hugging Face style `tokenizer_config.json` —
/// confirm against the caller.
// Only `chat_template` and the BOS/EOS/UNK token fields are read in this file; the rest are
// deserialized but unused, hence the `dead_code` allowance.
#[allow(dead_code)]
#[derive(Debug, Deserialize, Default)]
pub struct ChatTemplate {
    add_bos_token: Option<bool>,
    add_eos_token: Option<bool>,
    // NOTE(review): keys are presumably token ids encoded as strings — confirm.
    added_tokens_decoder: Option<HashMap<String, AddedTokensDecoder>>,
    additional_special_tokens: Option<Vec<String>>,
    /// Beginning-of-sequence token; exposed through `bos_tok`.
    pub bos_token: Option<BeginEndUnkPadTok>,

    /// Jinja chat template(s); see `ChatTemplateValue`.
    pub chat_template: Option<ChatTemplateValue>,
    clean_up_tokenization_spaces: Option<bool>,
    device_map: Option<String>,
    /// End-of-sequence token; exposed through `eos_tok`.
    pub eos_token: Option<BeginEndUnkPadTok>,
    legacy: Option<bool>,
    // NOTE(review): `f64` presumably because some configs store sentinel lengths too large
    // for integer types — confirm.
    model_max_length: Option<f64>,
    /// Padding token.
    pub pad_token: Option<BeginEndUnkPadTok>,
    sp_model_kwargs: Option<HashMap<String, String>>,
    spaces_between_special_tokens: Option<bool>,
    tokenizer_class: Option<String>,
    truncation_size: Option<String>,
    /// Unknown-token placeholder; exposed through `unk_tok`.
    pub unk_token: Option<BeginEndUnkPadTok>,
    use_default_system_prompt: Option<bool>,
}
82
83impl ChatTemplate {
84 pub fn has_chat_template(&self) -> bool {
85 self.chat_template.is_some()
86 }
87
88 pub(crate) fn get_template_contents(&self) -> Vec<String> {
89 match self.chat_template.as_ref() {
90 Some(t) => match &t.0 {
91 Either::Left(s) => vec![s.clone()],
92 Either::Right(vec) => vec.iter().flat_map(|m| m.values().cloned()).collect(),
93 },
94 None => vec![],
95 }
96 }
97
98 pub fn is_harmony_format(&self) -> bool {
100 self.get_template_contents()
101 .iter()
102 .any(|t| crate::reasoning_parsers::harmony::is_harmony_template(t))
103 }
104
105 pub fn uses_think_tags(&self) -> bool {
110 if self.is_harmony_format() {
112 return false;
113 }
114
115 self.get_template_contents()
116 .iter()
117 .any(|t| crate::reasoning_parsers::tag_based::is_think_tag_template(t))
118 }
119
120 pub fn uses_channel_tags(&self) -> bool {
122 self.get_template_contents()
123 .iter()
124 .any(|t| crate::reasoning_parsers::tag_based::is_channel_tag_template(t))
125 }
126
127 pub fn eos_tok(&self) -> Option<String> {
128 match self.eos_token.as_ref()?.0 {
129 Either::Left(ref lit) => Some(lit.clone()),
130 Either::Right(ref added) => Some(added.content.clone()),
131 }
132 }
133
134 pub fn bos_tok(&self) -> Option<String> {
135 match self.bos_token.as_ref()?.0 {
136 Either::Left(ref lit) => Some(lit.clone()),
137 Either::Right(ref added) => Some(added.content.clone()),
138 }
139 }
140
141 pub fn unk_tok(&self) -> Option<String> {
142 match self.unk_token.as_ref()?.0 {
143 Either::Left(ref lit) => Some(lit.clone()),
144 Either::Right(ref added) => Some(added.content.clone()),
145 }
146 }
147}
148
149pub fn calculate_eos_tokens(
150 chat_template: &ChatTemplate,
151 gen_conf: Option<&GenerationConfig>,
152 tokenizer: &Tokenizer,
153) -> Vec<u32> {
154 let mut eos_tok_ids = chat_template.eos_tok().map(|x| vec![x]).unwrap_or_default();
155 let mut bos_tok_ids = chat_template.bos_tok().map(|b| vec![b]).unwrap_or_default();
156
157 let templates = chat_template.get_template_contents();
158
159 for alternate in SUPPORTED_ALTERNATE_EOS {
160 if tokenizer.get_vocab(true).contains_key(*alternate)
161 && templates.iter().any(|t| t.contains(*alternate))
162 {
163 eos_tok_ids.push(alternate.to_string())
164 }
165 }
166
167 if let Some(gen_conf) = gen_conf {
168 if let Some(eos_field) = gen_conf.eos_token_id.as_ref() {
169 let ids = match eos_field {
170 Either::Left(id) => vec![*id],
171 Either::Right(ids) => ids.clone(),
172 };
173 for id in ids {
174 let s = tokenizer
175 .decode(&[id], false)
176 .unwrap_or_else(|_| panic!("Unable to decode id {id})"));
177 if !eos_tok_ids.contains(&s) {
178 eos_tok_ids.push(s);
179 }
180 }
181 }
182
183 if let Some(bos_field) = gen_conf.bos_token_id.as_ref() {
184 let ids = match bos_field {
185 Either::Left(id) => vec![*id],
186 Either::Right(ids) => ids.clone(),
187 };
188 for id in ids {
189 let s = tokenizer
190 .decode(&[id], false)
191 .unwrap_or_else(|_| panic!("Unable to decode id {id})"));
192 if !bos_tok_ids.contains(&s) {
193 bos_tok_ids.push(s);
194 }
195 }
196 }
197 }
198
199 eos_tok_ids = eos_tok_ids.into_iter().dedup().collect::<Vec<_>>();
200 bos_tok_ids = bos_tok_ids.into_iter().dedup().collect::<Vec<_>>();
201
202 let bos_render = bos_tok_ids
203 .iter()
204 .map(|val| format!("{val:?}"))
205 .collect::<Vec<String>>()
206 .join(", ");
207 let eos_render = eos_tok_ids
208 .iter()
209 .map(|val| format!("{val:?}"))
210 .collect::<Vec<String>>()
211 .join(", ");
212
213 trace!(
214 "bos_toks = {bos_render}, eos_toks = {eos_render}, unk_tok = {}",
215 chat_template.unk_tok().unwrap_or("`None`".to_string()),
216 );
217
218 let mut eos_toks = Vec::new();
219 for eos_tok in eos_tok_ids {
220 eos_toks.push(
221 tokenizer
222 .get_vocab(true)
223 .get(&eos_tok)
224 .copied()
225 .unwrap_or_else(|| panic!("Unable to extract `{eos_tok}` EOS token.")),
226 )
227 }
228 eos_toks
229}
230
/// Generation settings for a model.
///
/// NOTE(review): presumably deserialized from a `generation_config.json` — confirm against
/// the caller.
///
/// Every field may be omitted. `bos_token_id` / `eos_token_id` accept either a single id or a
/// list of ids.
// Every field is read in this file: `bos_token_id`/`eos_token_id` by `calculate_eos_tokens`,
// and the rest by `generation_defaults`.
// NOTE(review): the `dead_code` allowance therefore looks unnecessary.
#[allow(dead_code)]
#[derive(Debug, Clone, Deserialize)]
pub struct GenerationConfig {
    // `default` is required here: with a custom `with` module, serde no longer treats the
    // field as an implicitly optional `Option`.
    #[serde(default)]
    #[serde(with = "either::serde_untagged_optional")]
    bos_token_id: Option<Either<u32, Vec<u32>>>,
    #[serde(default)]
    #[serde(with = "either::serde_untagged_optional")]
    eos_token_id: Option<Either<u32, Vec<u32>>>,
    #[serde(default)]
    do_sample: Option<bool>,
    #[serde(default)]
    temperature: Option<f64>,
    #[serde(default)]
    top_k: Option<usize>,
    #[serde(default)]
    top_p: Option<f64>,
    #[serde(default)]
    min_p: Option<f64>,
    #[serde(default)]
    repetition_penalty: Option<f32>,
    #[serde(default)]
    max_new_tokens: Option<usize>,
    #[serde(default)]
    max_length: Option<usize>,
}
257
258impl GenerationConfig {
259 pub fn generation_defaults(&self) -> Option<ModelGenerationDefaults> {
260 let defaults = ModelGenerationDefaults {
261 do_sample: self.do_sample,
262 temperature: self.temperature,
263 top_k: self.top_k,
264 top_p: self.top_p,
265 min_p: self.min_p,
266 repetition_penalty: self.repetition_penalty,
267 max_new_tokens: self.max_new_tokens,
268 max_length: self.max_length,
269 };
270
271 if defaults.is_empty() {
272 None
273 } else {
274 Some(defaults)
275 }
276 }
277}
278
279fn tojson(value: Value, kwargs: Kwargs) -> Result<Value, Error> {
280 if let Ok(indent) = kwargs.get("indent") {
281 let mut buf = Vec::new();
282 let repeat = b" ".repeat(indent);
283 let formatter = serde_json::ser::PrettyFormatter::with_indent(&repeat);
284 let mut ser = serde_json::Serializer::with_formatter(&mut buf, formatter);
285 value.serialize(&mut ser).unwrap();
286 String::from_utf8(buf).map_err(|err| {
287 Error::new(ErrorKind::BadSerialization, "cannot serialize to JSON").with_source(err)
288 })
289 } else {
290 serde_json::to_string(&value).map_err(|err| {
291 Error::new(ErrorKind::BadSerialization, "cannot serialize to JSON").with_source(err)
292 })
293 }
294 .map_err(|err| {
295 Error::new(ErrorKind::InvalidOperation, "cannot serialize to JSON").with_source(err)
296 })
297 .map(|s| {
298 let mut rv = String::with_capacity(s.len());
300 for c in s.chars() {
301 match c {
302 '<' => rv.push_str("\\u003c"),
303 '>' => rv.push_str("\\u003e"),
304 '&' => rv.push_str("\\u0026"),
305 '\'' => rv.push_str("\\u0027"),
306 _ => rv.push(c),
307 }
308 }
309 Value::from_safe_string(rv)
310 })
311}
312
313fn strftime_now(fmt: String) -> Result<String, minijinja::Error> {
314 let date = chrono::Utc::now();
315 let date_string = date.format(&fmt).to_string();
316 Ok(date_string)
317}
318
319use crate::request::ReasoningEffort;
320
/// Returns `true` when `template` contains both Gemma 4 tool-call delimiters, `<|tool_call>`
/// and `<tool_call|>`.
fn is_gemma4_tool_template(template: &str) -> bool {
    const DELIMITERS: [&str; 2] = ["<|tool_call>", "<tool_call|>"];
    DELIMITERS
        .iter()
        .all(|delimiter| template.contains(*delimiter))
}
325
326fn parse_gemma4_tool_call_arguments(messages: &mut [IndexMap<String, MessageContent>]) {
334 for message in messages.iter_mut() {
335 let is_assistant = message
336 .get("role")
337 .and_then(|v| match v {
338 Either::Left(s) => Some(s.as_str()),
339 _ => None,
340 })
341 .is_some_and(|r| r == "assistant");
342 if !is_assistant {
343 continue;
344 }
345
346 let Some(Either::Right(tool_calls)) = message.get_mut("tool_calls") else {
347 continue;
348 };
349 for tc in tool_calls.iter_mut() {
350 let Some(serde_json::Value::Object(func)) = tc.get_mut("function") else {
352 continue;
353 };
354 if let Some(serde_json::Value::String(json_str)) = func.get("arguments") {
355 if let Ok(parsed) = serde_json::from_str::<serde_json::Value>(json_str) {
356 if parsed.is_object() {
357 func.insert("arguments".to_string(), parsed);
358 }
359 }
360 }
361 }
362 }
363}
364
365fn preprocess_gemma4_tool_messages(messages: &mut Vec<IndexMap<String, MessageContent>>) {
378 let mut result: Vec<IndexMap<String, MessageContent>> = Vec::with_capacity(messages.len());
379 let mut i = 0;
380
381 while i < messages.len() {
382 let is_tool = messages[i]
383 .get("role")
384 .and_then(|v| match v {
385 Either::Left(s) => Some(s.as_str()),
386 _ => None,
387 })
388 .is_some_and(|r| r == "tool");
389
390 if !is_tool {
391 let mut msg = std::mem::take(&mut messages[i]);
392
393 let is_assistant = msg
396 .get("role")
397 .and_then(|v| match v {
398 Either::Left(s) => Some(s.as_str()),
399 _ => None,
400 })
401 .is_some_and(|r| r == "assistant");
402 if is_assistant && (msg.contains_key("tool_calls") || !msg.contains_key("content")) {
403 msg.insert("content".to_string(), Either::Left(String::new()));
404 }
405
406 result.push(msg);
407 i += 1;
408 continue;
409 }
410
411 let mut tool_responses: Vec<IndexMap<String, serde_json::Value>> = Vec::new();
413 let mut media_parts: Vec<IndexMap<String, serde_json::Value>> = Vec::new();
414 while i < messages.len() {
415 let is_tool = messages[i]
416 .get("role")
417 .and_then(|v| match v {
418 Either::Left(s) => Some(s.as_str()),
419 _ => None,
420 })
421 .is_some_and(|r| r == "tool");
422 if !is_tool {
423 break;
424 }
425
426 let tool_msg = &messages[i];
427
428 let name = tool_msg
429 .get("name")
430 .and_then(|v| match v {
431 Either::Left(s) => Some(s.clone()),
432 _ => None,
433 })
434 .unwrap_or_else(|| "unknown".to_string());
435
436 let content = match tool_msg.get("content") {
437 Some(Either::Left(s)) => s.clone(),
438 Some(Either::Right(parts)) => {
439 let mut text = String::new();
440 for part in parts {
441 match part.get("type").and_then(|v| v.as_str()) {
442 Some("text") => {
443 if let Some(t) = part.get("text").and_then(|v| v.as_str()) {
444 text.push_str(t);
445 }
446 }
447 Some("image") | Some("audio") | Some("video") => {
448 media_parts.push(part.clone());
449 }
450 _ => {}
451 }
452 }
453 text
454 }
455 _ => String::new(),
456 };
457
458 let response_value: serde_json::Value =
459 serde_json::from_str(&content).unwrap_or(serde_json::Value::String(content));
460
461 let mut entry = IndexMap::new();
462 entry.insert("name".to_string(), serde_json::Value::String(name));
463 entry.insert("response".to_string(), response_value);
464 tool_responses.push(entry);
465
466 i += 1;
467 }
468
469 let mut user_msg: IndexMap<String, MessageContent> = IndexMap::new();
471 user_msg.insert("role".to_string(), Either::Left("user".to_string()));
472 user_msg.insert("tool_responses".to_string(), Either::Right(tool_responses));
473 if !media_parts.is_empty() {
474 user_msg.insert("content".to_string(), Either::Right(media_parts));
475 }
476 result.push(user_msg);
477 }
478
479 *messages = result;
480}
481
482#[allow(clippy::too_many_arguments)]
483pub fn apply_chat_template_to(
484 mut messages: Vec<IndexMap<String, MessageContent>>,
485 add_generation_prompt: bool,
486 enable_thinking: Option<bool>,
487 reasoning_effort: Option<ReasoningEffort>,
488 template: &ChatTemplateValue,
489 bos_tok: Option<String>,
490 eos_tok: Option<String>,
491 unk_tok: Option<String>,
492 tools: Vec<Tool>,
493) -> Result<String> {
494 let mut env = Environment::new();
495
496 env.set_unknown_method_callback(minijinja_contrib::pycompat::unknown_method_callback);
498
499 env.set_lstrip_blocks(true);
501 env.set_trim_blocks(true);
502
503 #[derive(Serialize, Deserialize)]
504 struct UntaggedContent(#[serde(with = "either::serde_untagged")] MessageContent);
505
506 let resolved_template = match &template.0 {
508 Either::Left(x) => x.clone(),
509 Either::Right(map) => {
510 let has_tool_use = map.iter().any(|t| {
511 t.get("name").is_some_and(|name| name == "tool_use") || t.contains_key("tool_use")
512 });
513 let must_use_tool_template = !tools.is_empty();
514
515 if must_use_tool_template && !has_tool_use {
516 anyhow::bail!(
517 "Tools were provided but this chat template does not handle tool usage"
518 );
519 }
520
521 let mut found_template = None;
522 for t in map {
523 let name = t.get("name");
524 if let Some(name) = name {
525 found_template = Some(t["template"].clone());
526 #[allow(clippy::if_same_then_else)]
527 if name == "tool_use" && !tools.is_empty() {
528 break;
529 } else if name == "default" && !must_use_tool_template {
530 break;
531 }
532 } else if t.contains_key("tool_use") && !tools.is_empty() {
533 found_template = Some(t["tool_use"].clone());
534 break;
535 } else if t.contains_key("default") && !must_use_tool_template {
536 found_template = Some(t["default"].clone());
537 break;
538 }
539 }
540
541 found_template.ok_or_else(|| anyhow::anyhow!("Chat template does not contain a `tool_use` or `default` key. Please ensure it contains at least a `default` key, although `tool_use` should be specified for using tools."))?
542 }
543 };
544
545 if is_gemma4_tool_template(&resolved_template) {
550 parse_gemma4_tool_call_arguments(&mut messages);
551 preprocess_gemma4_tool_messages(&mut messages);
552 }
553
554 let mut new_messages = Vec::new();
555 for message in messages {
556 let mut new_message = IndexMap::new();
557 for (k, v) in message {
558 new_message.insert(k, UntaggedContent(v));
559 }
560 new_messages.push(new_message);
561 }
562
563 let mut template = resolved_template.replace("[::-1]", "|reverse");
565 let re = Regex::new(r"range\((?P<expr>[^,]+),\s*-1,\s*-1\)").unwrap();
569 template = re
570 .replace_all(&template, |caps: ®ex::Captures| {
571 format!("range({})|reverse", &caps["expr"])
572 })
573 .into_owned();
574
575 if template.contains("{{ meta }}") {
576 template = template.replace("{%- set meta = message.get(\"metadata\", \"\") %}", "");
578 template = template.replace("{{ meta }}", "");
579 }
580 if template.contains("{% generation %}") && template.contains("{% endgeneration %}") {
581 template = template.replace("{% generation %}", "");
583 template = template.replace("{% endgeneration %}", "");
584 }
585
586 env.add_template("chat_template", &template)?;
587 env.add_function("raise_exception", raise_exception);
588 env.add_filter("tojson", tojson);
589 env.add_function("strftime_now", strftime_now);
590 let tmpl = env.get_template("chat_template").unwrap();
591
592 let date = chrono::Utc::now();
593 let date_string = date.format("%d, %B, %Y").to_string();
594
595 let reasoning_effort_str = reasoning_effort.map(|r| r.as_str()).unwrap_or("medium");
597
598 let builtin_tool_names = [
602 "browser",
603 "python",
604 "code_interpreter",
605 "web_search",
606 "brave_search",
607 "wolfram_alpha",
608 ];
609 let builtin_tools: Vec<&str> = tools
610 .iter()
611 .filter_map(|t| {
612 let name = t.function.name.as_str();
613 if builtin_tool_names.contains(&name) {
614 Some(name)
615 } else {
616 None
617 }
618 })
619 .collect();
620
621 let is_gemma4 = is_gemma4_tool_template(&resolved_template);
622
623 let mut rendered = if tools.is_empty() {
624 tmpl.render(context! {
625 messages => new_messages,
626 add_generation_prompt => add_generation_prompt,
627 bos_token => bos_tok,
628 eos_token => eos_tok,
629 unk_token => unk_tok,
630 date_string => date_string,
631 enable_thinking => enable_thinking.unwrap_or(DEFAULT_ENABLE_THINKING),
632 reasoning_effort => reasoning_effort_str,
633 })?
634 } else {
635 tmpl.render(context! {
636 messages => new_messages,
637 add_generation_prompt => add_generation_prompt,
638 bos_token => bos_tok,
639 eos_token => eos_tok,
640 unk_token => unk_tok,
641 xml_tools => tools.clone(), tools => tools,
643 builtin_tools => builtin_tools,
644 date_string => date_string,
645 enable_thinking => enable_thinking.unwrap_or(DEFAULT_ENABLE_THINKING),
646 reasoning_effort => reasoning_effort_str,
647 })?
648 };
649
650 if is_gemma4 && add_generation_prompt && rendered.ends_with("<tool_response|>") {
656 rendered.push_str("<|turn>model\n");
657 }
658
659 Ok(rendered)
660}
661
#[cfg(test)]
mod tests {
    // Unit tests for template rendering, generation-config parsing, and the Gemma 4 message
    // preprocessing helpers.
    use either::Either;
    use indexmap::IndexMap;
    use serde_json::Value;

    use super::{
        apply_chat_template_to, preprocess_gemma4_tool_messages, ChatTemplateValue,
        GenerationConfig, DEFAULT_ENABLE_THINKING,
    };
    use crate::MessageContent;

    /// Builds `{"role": "user", "content": text}`.
    fn user_text_message(text: &str) -> IndexMap<String, MessageContent> {
        IndexMap::from([
            ("role".to_string(), Either::Left("user".to_string())),
            ("content".to_string(), Either::Left(text.to_string())),
        ])
    }

    #[test]
    fn unspecified_thinking_enables_template_thinking() {
        let template = ChatTemplateValue(Either::Left(
            "{% if enable_thinking is defined and enable_thinking %}<|think|>{% endif %}{{ bos_token }}{{ messages[0]['content'] }}".to_string(),
        ));
        let messages = vec![user_text_message("hello")];

        let rendered = apply_chat_template_to(
            messages,
            false,
            None,
            None,
            &template,
            Some("<bos>".to_string()),
            None,
            None,
            vec![],
        )
        .unwrap();
        let enabled = apply_chat_template_to(
            vec![user_text_message("hello")],
            false,
            Some(true),
            None,
            &template,
            Some("<bos>".to_string()),
            None,
            None,
            vec![],
        )
        .unwrap();

        const { assert!(DEFAULT_ENABLE_THINKING) };
        assert_eq!(rendered, "<|think|><bos>hello");
        assert_eq!(rendered, enabled);
    }

    #[test]
    fn generation_config_exposes_sampling_defaults() {
        let config: GenerationConfig = serde_json::from_str(
            r#"{
                "do_sample": true,
                "temperature": 1.0,
                "top_k": 32,
                "top_p": 0.9,
                "min_p": 0.05,
                "repetition_penalty": 1.1,
                "max_new_tokens": 512
            }"#,
        )
        .unwrap();

        let defaults = config.generation_defaults().unwrap();
        assert_eq!(defaults.do_sample, Some(true));
        assert_eq!(defaults.temperature, Some(1.0));
        assert_eq!(defaults.top_k, Some(32));
        assert_eq!(defaults.top_p, Some(0.9));
        assert_eq!(defaults.min_p, Some(0.05));
        assert_eq!(defaults.repetition_penalty, Some(1.1));
        assert_eq!(defaults.max_new_tokens, Some(512));
    }

    /// Builds an assistant message with one `get_weather` call whose arguments are a JSON string.
    fn assistant_message_with_tool_calls() -> IndexMap<String, MessageContent> {
        let mut tc_map = IndexMap::new();
        tc_map.insert("id".to_string(), Value::String("call-1".to_string()));
        tc_map.insert("type".to_string(), Value::String("function".to_string()));
        let mut func = serde_json::Map::new();
        func.insert("name".to_string(), Value::String("get_weather".to_string()));
        func.insert(
            "arguments".to_string(),
            Value::String(r#"{"city":"Boston"}"#.to_string()),
        );
        tc_map.insert("function".to_string(), Value::Object(func));

        IndexMap::from([
            ("role".to_string(), Either::Left("assistant".to_string())),
            (
                "content".to_string(),
                Either::Left(
                    r#"{"name":"get_weather","arguments":"{\"city\":\"Boston\"}"}"#.to_string(),
                ),
            ),
            ("tool_calls".to_string(), Either::Right(vec![tc_map])),
        ])
    }

    /// Builds a `tool` message with the given tool name and plain-text content.
    fn tool_result_message(name: &str, content: &str) -> IndexMap<String, MessageContent> {
        IndexMap::from([
            ("role".to_string(), Either::Left("tool".to_string())),
            ("name".to_string(), Either::Left(name.to_string())),
            ("content".to_string(), Either::Left(content.to_string())),
        ])
    }

    #[test]
    fn gemma4_preprocess_creates_user_msg_for_tool_responses() {
        let mut messages = vec![
            user_text_message("What's the weather?"),
            assistant_message_with_tool_calls(),
            tool_result_message("get_weather", r#"{"temp":72}"#),
        ];

        preprocess_gemma4_tool_messages(&mut messages);

        assert_eq!(messages.len(), 3);
        assert!(!messages[1].contains_key("tool_responses"));
        let content = messages[1].get("content").unwrap();
        assert_eq!(content, &Either::Left(String::new()));
        let role = messages[2].get("role").unwrap();
        assert_eq!(role, &Either::Left("user".to_string()));
        assert!(messages[2].contains_key("tool_responses"));
    }

    #[test]
    fn gemma4_preprocess_tool_response_has_correct_structure() {
        let mut messages = vec![
            user_text_message("hi"),
            assistant_message_with_tool_calls(),
            tool_result_message("get_weather", r#"{"temp":72}"#),
        ];

        preprocess_gemma4_tool_messages(&mut messages);

        let tool_responses = match messages[2].get("tool_responses").unwrap() {
            Either::Right(v) => v,
            _ => panic!("Expected Either::Right"),
        };
        assert_eq!(tool_responses.len(), 1);
        assert_eq!(tool_responses[0]["name"], "get_weather");
        assert_eq!(tool_responses[0]["response"]["temp"], 72);
    }

    #[test]
    fn gemma4_parse_tool_call_arguments_converts_json_string_to_object() {
        let mut messages = vec![
            user_text_message("call something"),
            assistant_message_with_tool_calls(),
        ];
        // Precondition: the fixture stores the arguments as a JSON string.
        let Some(Either::Right(tcs)) = messages[1].get("tool_calls") else {
            panic!("expected tool_calls");
        };
        let func = tcs[0].get("function").unwrap();
        assert!(func.get("arguments").unwrap().is_string());

        super::parse_gemma4_tool_call_arguments(&mut messages);

        let Some(Either::Right(tcs)) = messages[1].get("tool_calls") else {
            panic!("expected tool_calls");
        };
        let func = tcs[0].get("function").unwrap();
        let args = func.get("arguments").unwrap();
        assert!(args.is_object(), "arguments should be parsed to object");
        assert_eq!(args.get("city").unwrap(), "Boston");
    }

    #[test]
    fn gemma4_preprocess_multiple_tool_messages() {
        let mut messages = vec![
            user_text_message("hi"),
            assistant_message_with_tool_calls(),
            tool_result_message("get_weather", r#"{"temp":72}"#),
            tool_result_message("get_forecast", "sunny"),
        ];

        preprocess_gemma4_tool_messages(&mut messages);

        assert_eq!(messages.len(), 3);
        let tool_responses = match messages[2].get("tool_responses").unwrap() {
            Either::Right(v) => v,
            _ => panic!("Expected Either::Right"),
        };
        assert_eq!(tool_responses.len(), 2);
        assert_eq!(tool_responses[0]["name"], "get_weather");
        assert_eq!(tool_responses[1]["name"], "get_forecast");
        assert_eq!(tool_responses[1]["response"], "sunny");
    }

    #[test]
    fn gemma4_preprocess_no_tool_messages_is_noop() {
        let mut messages = vec![
            user_text_message("hello"),
            IndexMap::from([
                ("role".to_string(), Either::Left("assistant".to_string())),
                ("content".to_string(), Either::Left("hi there".to_string())),
            ]),
        ];
        let original_len = messages.len();

        preprocess_gemma4_tool_messages(&mut messages);

        assert_eq!(messages.len(), original_len);
    }

    #[test]
    fn gemma4_preprocess_tool_without_name_defaults_to_unknown() {
        let mut messages = vec![
            user_text_message("hi"),
            assistant_message_with_tool_calls(),
            IndexMap::from([
                ("role".to_string(), Either::Left("tool".to_string())),
                ("content".to_string(), Either::Left("result".to_string())),
            ]),
        ];

        preprocess_gemma4_tool_messages(&mut messages);

        let tool_responses = match messages[2].get("tool_responses").unwrap() {
            Either::Right(v) => v,
            _ => panic!("Expected Either::Right"),
        };
        assert_eq!(tool_responses[0]["name"], "unknown");
    }

    #[test]
    fn gemma4_preprocess_assistant_without_content_gets_empty_content() {
        let mut messages = vec![
            user_text_message("hi"),
            IndexMap::from([("role".to_string(), Either::Left("assistant".to_string()))]),
        ];

        preprocess_gemma4_tool_messages(&mut messages);

        assert_eq!(messages.len(), 2);
        assert_eq!(
            messages[1].get("content"),
            Some(&Either::Left(String::new()))
        );
    }

    #[test]
    fn gemma4_preprocess_multipart_tool_output_splits_text_and_media() {
        let text_part = IndexMap::from([
            ("type".to_string(), Value::String("text".to_string())),
            ("text".to_string(), Value::String(r#"{"ok":true}"#.to_string())),
        ]);
        let image_part = IndexMap::from([
            ("type".to_string(), Value::String("image".to_string())),
            ("url".to_string(), Value::String("img.png".to_string())),
        ]);
        let mut messages = vec![
            user_text_message("hi"),
            assistant_message_with_tool_calls(),
            IndexMap::from([
                ("role".to_string(), Either::Left("tool".to_string())),
                ("name".to_string(), Either::Left("snapshot".to_string())),
                (
                    "content".to_string(),
                    Either::Right(vec![text_part, image_part.clone()]),
                ),
            ]),
        ];

        preprocess_gemma4_tool_messages(&mut messages);

        assert_eq!(messages.len(), 3);
        let Some(Either::Right(tool_responses)) = messages[2].get("tool_responses") else {
            panic!("expected tool_responses");
        };
        assert_eq!(tool_responses[0]["name"], "snapshot");
        assert_eq!(tool_responses[0]["response"]["ok"], true);
        assert_eq!(
            messages[2].get("content"),
            Some(&Either::Right(vec![image_part]))
        );
    }

    #[test]
    fn gemma4_generation_prompt_opens_model_turn_after_tool_responses() {
        // The comment makes the template look like a Gemma 4 tool template.
        let template = ChatTemplateValue(Either::Left(
            "{# <|tool_call> <tool_call|> #}{% for m in messages %}{% if m.tool_responses is defined %}<tool_response|>{% endif %}{% endfor %}".to_string(),
        ));
        let messages = vec![
            user_text_message("What's the weather?"),
            assistant_message_with_tool_calls(),
            tool_result_message("get_weather", r#"{"temp":72}"#),
        ];

        let rendered =
            apply_chat_template_to(messages, true, None, None, &template, None, None, None, vec![])
                .unwrap();

        assert_eq!(rendered, "<tool_response|><|turn>model\n");
    }

    #[test]
    fn generation_config_keeps_omitted_sampling_fields_unset() {
        let config: GenerationConfig = serde_json::from_str(
            r#"{
                "do_sample": true,
                "temperature": 1.0
            }"#,
        )
        .unwrap();

        let defaults = config.generation_defaults().unwrap();
        assert_eq!(defaults.do_sample, Some(true));
        assert_eq!(defaults.temperature, Some(1.0));
        assert_eq!(defaults.top_k, None);
        assert_eq!(defaults.top_p, None);
        assert_eq!(defaults.min_p, None);
        assert_eq!(defaults.repetition_penalty, None);
        assert_eq!(defaults.max_new_tokens, None);
        assert_eq!(defaults.max_length, None);
    }
}