Skip to main content

hanzo_engine/pipeline/
chat_template.rs

1use std::collections::HashMap;
2
3use anyhow::Result;
4use either::Either;
5use indexmap::IndexMap;
6use itertools::Itertools;
7use minijinja::{context, value::Kwargs, Environment, Error, ErrorKind, Value};
8use regex::Regex;
9use serde::{Deserialize, Serialize};
10use tokenizers::Tokenizer;
11use tracing::trace;
12
13use crate::{MessageContent, ModelGenerationDefaults, Tool};
14
/// Special tokens that may act as additional end-of-sequence markers.
///
/// `calculate_eos_tokens` adds an entry from this list to the EOS set only
/// when the token exists in the tokenizer vocabulary *and* its text appears
/// in one of the chat template strings.
const SUPPORTED_ALTERNATE_EOS: &[&str] = &[
    "<|im_end|>",      // Handle ChatML case
    "<end_of_turn>",   // Handle Gemma2 chat case
    "<|end_of_text|>", // Hermes
    "<|end|>",         // Phi-3, Phi-3.5, Harmony
    "<|eot_id|>",      // Llama 3
    "<|message|>",     // Harmony
    "<|start|>",       // Harmony
    "<|channel|>",     // Harmony
];

/// Repository default for templates that support an explicit thinking toggle.
///
/// `apply_chat_template_to` passes this value to the template as
/// `enable_thinking` when the caller supplies `None`.
const DEFAULT_ENABLE_THINKING: bool = true;
28
/// Deserializable, structured description of a single special token.
///
/// Used as the value type of `ChatTemplate::added_tokens_decoder` and as the
/// structured alternative inside [`BeginEndUnkPadTok`].
#[allow(dead_code)]
#[derive(Debug, Deserialize)]
pub struct AddedTokensDecoder {
    // Optional type tag carried by the serialized entry.
    __type: Option<String>,
    /// Literal text of the token; this is what `eos_tok`/`bos_tok`/`unk_tok`
    /// return for the structured form.
    pub content: String,
    // The flags below are deserialized but never read in this module (hence
    // `allow(dead_code)`). NOTE(review): presumably tokenizer matching
    // options mirrored from the tokenizer config — confirm.
    lstrip: bool,
    normalized: bool,
    rstrip: bool,
    single_word: bool,
    special: Option<bool>,
}
40
41fn raise_exception(msg: String) -> Result<String, minijinja::Error> {
42    Err(minijinja::Error::new(ErrorKind::InvalidOperation, msg))
43}
44
/// A BOS/EOS/UNK/PAD token value.
///
/// Deserialized untagged: either a plain string holding the token text
/// (`Either::Left`) or a structured [`AddedTokensDecoder`] entry
/// (`Either::Right`) whose `content` field holds the text.
#[derive(Debug, Deserialize)]
pub struct BeginEndUnkPadTok(
    #[serde(with = "either::serde_untagged")] pub Either<String, AddedTokensDecoder>,
);
49
/// The chat template payload.
///
/// Deserialized untagged into one of two shapes:
/// - `Either::Left`: a single Jinja template string.
/// - `Either::Right`: a list of string maps. `apply_chat_template_to`
///   understands entries of the form `{"name": ..., "template": ...}` as well
///   as maps keyed directly by `tool_use` / `default`.
#[derive(Debug, Deserialize)]
pub struct ChatTemplateValue(
    #[serde(with = "either::serde_untagged")] pub Either<String, Vec<HashMap<String, String>>>,
);
54
#[allow(dead_code)]
#[derive(Debug, Deserialize, Default)]
/// Template for chat models including bos/eos/unk as well as the chat template.
///
/// Every field is optional, so a partially populated configuration still
/// deserializes and `Default` yields an instance with no tokens and no template.
/// NOTE(review): the field names presumably mirror a Hugging Face
/// `tokenizer_config.json` — confirm against the loader.
pub struct ChatTemplate {
    // Private fields are deserialized but not read in this module
    // (hence `allow(dead_code)`).
    add_bos_token: Option<bool>,
    add_eos_token: Option<bool>,
    added_tokens_decoder: Option<HashMap<String, AddedTokensDecoder>>,
    additional_special_tokens: Option<Vec<String>>,
    /// Beginning-of-sequence token; see [`ChatTemplate::bos_tok`] for its text.
    pub bos_token: Option<BeginEndUnkPadTok>,

    /// Jinja format [chat templating] for chat completion.
    ///
    /// [chat templating]: https://huggingface.co/docs/transformers/chat_templating
    pub chat_template: Option<ChatTemplateValue>,
    clean_up_tokenization_spaces: Option<bool>,
    device_map: Option<String>,
    /// End-of-sequence token; see [`ChatTemplate::eos_tok`] for its text.
    pub eos_token: Option<BeginEndUnkPadTok>,
    legacy: Option<bool>,
    model_max_length: Option<f64>,
    /// Padding token, in the same plain-or-structured form as the others.
    pub pad_token: Option<BeginEndUnkPadTok>,
    sp_model_kwargs: Option<HashMap<String, String>>,
    spaces_between_special_tokens: Option<bool>,
    tokenizer_class: Option<String>,
    truncation_size: Option<String>,
    /// Unknown token; see [`ChatTemplate::unk_tok`] for its text.
    pub unk_token: Option<BeginEndUnkPadTok>,
    use_default_system_prompt: Option<bool>,
}
82
83impl ChatTemplate {
84    pub fn has_chat_template(&self) -> bool {
85        self.chat_template.is_some()
86    }
87
88    pub(crate) fn get_template_contents(&self) -> Vec<String> {
89        match self.chat_template.as_ref() {
90            Some(t) => match &t.0 {
91                Either::Left(s) => vec![s.clone()],
92                Either::Right(vec) => vec.iter().flat_map(|m| m.values().cloned()).collect(),
93            },
94            None => vec![],
95        }
96    }
97
98    /// Check if this chat template uses OpenAI Harmony format.
99    pub fn is_harmony_format(&self) -> bool {
100        self.get_template_contents()
101            .iter()
102            .any(|t| crate::reasoning_parsers::harmony::is_harmony_template(t))
103    }
104
105    /// Check if this chat template uses `<think>...</think>` tags for reasoning.
106    ///
107    /// This is mutually exclusive with Harmony format - if the template uses
108    /// Harmony format, this returns false even if think tags are present.
109    pub fn uses_think_tags(&self) -> bool {
110        // Don't enable if Harmony format is detected (mutual exclusivity)
111        if self.is_harmony_format() {
112            return false;
113        }
114
115        self.get_template_contents()
116            .iter()
117            .any(|t| crate::reasoning_parsers::tag_based::is_think_tag_template(t))
118    }
119
120    /// Check if the template uses Gemma 4 channel-based reasoning tags.
121    pub fn uses_channel_tags(&self) -> bool {
122        self.get_template_contents()
123            .iter()
124            .any(|t| crate::reasoning_parsers::tag_based::is_channel_tag_template(t))
125    }
126
127    pub fn eos_tok(&self) -> Option<String> {
128        match self.eos_token.as_ref()?.0 {
129            Either::Left(ref lit) => Some(lit.clone()),
130            Either::Right(ref added) => Some(added.content.clone()),
131        }
132    }
133
134    pub fn bos_tok(&self) -> Option<String> {
135        match self.bos_token.as_ref()?.0 {
136            Either::Left(ref lit) => Some(lit.clone()),
137            Either::Right(ref added) => Some(added.content.clone()),
138        }
139    }
140
141    pub fn unk_tok(&self) -> Option<String> {
142        match self.unk_token.as_ref()?.0 {
143            Either::Left(ref lit) => Some(lit.clone()),
144            Either::Right(ref added) => Some(added.content.clone()),
145        }
146    }
147}
148
149pub fn calculate_eos_tokens(
150    chat_template: &ChatTemplate,
151    gen_conf: Option<&GenerationConfig>,
152    tokenizer: &Tokenizer,
153) -> Vec<u32> {
154    let mut eos_tok_ids = chat_template.eos_tok().map(|x| vec![x]).unwrap_or_default();
155    let mut bos_tok_ids = chat_template.bos_tok().map(|b| vec![b]).unwrap_or_default();
156
157    let templates = chat_template.get_template_contents();
158
159    for alternate in SUPPORTED_ALTERNATE_EOS {
160        if tokenizer.get_vocab(true).contains_key(*alternate)
161            && templates.iter().any(|t| t.contains(*alternate))
162        {
163            eos_tok_ids.push(alternate.to_string())
164        }
165    }
166
167    if let Some(gen_conf) = gen_conf {
168        if let Some(eos_field) = gen_conf.eos_token_id.as_ref() {
169            let ids = match eos_field {
170                Either::Left(id) => vec![*id],
171                Either::Right(ids) => ids.clone(),
172            };
173            for id in ids {
174                let s = tokenizer
175                    .decode(&[id], false)
176                    .unwrap_or_else(|_| panic!("Unable to decode id {id})"));
177                if !eos_tok_ids.contains(&s) {
178                    eos_tok_ids.push(s);
179                }
180            }
181        }
182
183        if let Some(bos_field) = gen_conf.bos_token_id.as_ref() {
184            let ids = match bos_field {
185                Either::Left(id) => vec![*id],
186                Either::Right(ids) => ids.clone(),
187            };
188            for id in ids {
189                let s = tokenizer
190                    .decode(&[id], false)
191                    .unwrap_or_else(|_| panic!("Unable to decode id {id})"));
192                if !bos_tok_ids.contains(&s) {
193                    bos_tok_ids.push(s);
194                }
195            }
196        }
197    }
198
199    eos_tok_ids = eos_tok_ids.into_iter().dedup().collect::<Vec<_>>();
200    bos_tok_ids = bos_tok_ids.into_iter().dedup().collect::<Vec<_>>();
201
202    let bos_render = bos_tok_ids
203        .iter()
204        .map(|val| format!("{val:?}"))
205        .collect::<Vec<String>>()
206        .join(", ");
207    let eos_render = eos_tok_ids
208        .iter()
209        .map(|val| format!("{val:?}"))
210        .collect::<Vec<String>>()
211        .join(", ");
212
213    trace!(
214        "bos_toks = {bos_render}, eos_toks = {eos_render}, unk_tok = {}",
215        chat_template.unk_tok().unwrap_or("`None`".to_string()),
216    );
217
218    let mut eos_toks = Vec::new();
219    for eos_tok in eos_tok_ids {
220        eos_toks.push(
221            tokenizer
222                .get_vocab(true)
223                .get(&eos_tok)
224                .copied()
225                .unwrap_or_else(|| panic!("Unable to extract `{eos_tok}` EOS token.")),
226        )
227    }
228    eos_toks
229}
230
/// Generation settings shipped alongside a model.
///
/// Every field is optional and defaults to `None` when absent from the
/// serialized form.
/// NOTE(review): presumably mirrors a Hugging Face `generation_config.json` —
/// confirm against the loader.
#[allow(dead_code)]
#[derive(Debug, Clone, Deserialize)]
pub struct GenerationConfig {
    // BOS id(s): accepts either a single id or a list of ids.
    #[serde(default)]
    #[serde(with = "either::serde_untagged_optional")]
    bos_token_id: Option<Either<u32, Vec<u32>>>,
    // EOS id(s): accepts either a single id or a list of ids.
    #[serde(default)]
    #[serde(with = "either::serde_untagged_optional")]
    eos_token_id: Option<Either<u32, Vec<u32>>>,
    // The sampling fields below are forwarded unchanged by
    // `generation_defaults`.
    #[serde(default)]
    do_sample: Option<bool>,
    #[serde(default)]
    temperature: Option<f64>,
    #[serde(default)]
    top_k: Option<usize>,
    #[serde(default)]
    top_p: Option<f64>,
    #[serde(default)]
    min_p: Option<f64>,
    #[serde(default)]
    repetition_penalty: Option<f32>,
    #[serde(default)]
    max_new_tokens: Option<usize>,
    #[serde(default)]
    max_length: Option<usize>,
}
257
258impl GenerationConfig {
259    pub fn generation_defaults(&self) -> Option<ModelGenerationDefaults> {
260        let defaults = ModelGenerationDefaults {
261            do_sample: self.do_sample,
262            temperature: self.temperature,
263            top_k: self.top_k,
264            top_p: self.top_p,
265            min_p: self.min_p,
266            repetition_penalty: self.repetition_penalty,
267            max_new_tokens: self.max_new_tokens,
268            max_length: self.max_length,
269        };
270
271        if defaults.is_empty() {
272            None
273        } else {
274            Some(defaults)
275        }
276    }
277}
278
279fn tojson(value: Value, kwargs: Kwargs) -> Result<Value, Error> {
280    if let Ok(indent) = kwargs.get("indent") {
281        let mut buf = Vec::new();
282        let repeat = b" ".repeat(indent);
283        let formatter = serde_json::ser::PrettyFormatter::with_indent(&repeat);
284        let mut ser = serde_json::Serializer::with_formatter(&mut buf, formatter);
285        value.serialize(&mut ser).unwrap();
286        String::from_utf8(buf).map_err(|err| {
287            Error::new(ErrorKind::BadSerialization, "cannot serialize to JSON").with_source(err)
288        })
289    } else {
290        serde_json::to_string(&value).map_err(|err| {
291            Error::new(ErrorKind::BadSerialization, "cannot serialize to JSON").with_source(err)
292        })
293    }
294    .map_err(|err| {
295        Error::new(ErrorKind::InvalidOperation, "cannot serialize to JSON").with_source(err)
296    })
297    .map(|s| {
298        // When this filter is used the return value is safe for both HTML and JSON
299        let mut rv = String::with_capacity(s.len());
300        for c in s.chars() {
301            match c {
302                '<' => rv.push_str("\\u003c"),
303                '>' => rv.push_str("\\u003e"),
304                '&' => rv.push_str("\\u0026"),
305                '\'' => rv.push_str("\\u0027"),
306                _ => rv.push(c),
307            }
308        }
309        Value::from_safe_string(rv)
310    })
311}
312
313fn strftime_now(fmt: String) -> Result<String, minijinja::Error> {
314    let date = chrono::Utc::now();
315    let date_string = date.format(&fmt).to_string();
316    Ok(date_string)
317}
318
319use crate::request::ReasoningEffort;
320
/// Report whether a chat template contains both Gemma 4 tool-call
/// delimiters, `<|tool_call>` and `<tool_call|>`.
fn is_gemma4_tool_template(template: &str) -> bool {
    const DELIMITERS: [&str; 2] = ["<|tool_call>", "<tool_call|>"];
    DELIMITERS
        .iter()
        .all(|&delimiter| template.contains(delimiter))
}
325
326/// Parse tool_call `arguments` fields from JSON strings into objects.
327///
328/// The OpenAI API returns `tool_calls[i].function.arguments` as a JSON string,
329/// but the Gemma 4 chat template's `format_argument` macro only emits the
330/// correct `<|"|>` delimited format when `arguments` is a mapping (object).
331/// When it's a string, the raw JSON is rendered verbatim, producing a format
332/// mismatch that confuses the model in multi-turn tool-calling conversations.
333fn parse_gemma4_tool_call_arguments(messages: &mut [IndexMap<String, MessageContent>]) {
334    for message in messages.iter_mut() {
335        let is_assistant = message
336            .get("role")
337            .and_then(|v| match v {
338                Either::Left(s) => Some(s.as_str()),
339                _ => None,
340            })
341            .is_some_and(|r| r == "assistant");
342        if !is_assistant {
343            continue;
344        }
345
346        let Some(Either::Right(tool_calls)) = message.get_mut("tool_calls") else {
347            continue;
348        };
349        for tc in tool_calls.iter_mut() {
350            // tool_calls[i].function.arguments
351            let Some(serde_json::Value::Object(func)) = tc.get_mut("function") else {
352                continue;
353            };
354            if let Some(serde_json::Value::String(json_str)) = func.get("arguments") {
355                if let Ok(parsed) = serde_json::from_str::<serde_json::Value>(json_str) {
356                    if parsed.is_object() {
357                        func.insert("arguments".to_string(), parsed);
358                    }
359                }
360            }
361        }
362    }
363}
364
365/// Pre-process messages for Gemma 4 tool templates.
366///
367/// The Gemma 4 chat template expects `tool_responses` as a field on a
368/// **user** message, but the OpenAI API sends `role: "tool"` as separate
369/// messages. This function replaces consecutive `role: "tool"` messages
370/// with a single `role: "user"` message carrying the `tool_responses`
371/// field, matching the format used by the reference implementations
372/// (llama.cpp `convert_tool_responses_gemma4`, HF transformers).
373///
374/// Additionally, when the preceding assistant message has structured
375/// `tool_calls`, its raw-JSON `content` is cleared so the template only
376/// renders the `<|tool_call>` tags.
377fn preprocess_gemma4_tool_messages(messages: &mut Vec<IndexMap<String, MessageContent>>) {
378    let mut result: Vec<IndexMap<String, MessageContent>> = Vec::with_capacity(messages.len());
379    let mut i = 0;
380
381    while i < messages.len() {
382        let is_tool = messages[i]
383            .get("role")
384            .and_then(|v| match v {
385                Either::Left(s) => Some(s.as_str()),
386                _ => None,
387            })
388            .is_some_and(|r| r == "tool");
389
390        if !is_tool {
391            let mut msg = std::mem::take(&mut messages[i]);
392
393            // When an assistant message has structured tool_calls, clear the
394            // raw-JSON content so the template only renders <|tool_call> tags.
395            let is_assistant = msg
396                .get("role")
397                .and_then(|v| match v {
398                    Either::Left(s) => Some(s.as_str()),
399                    _ => None,
400                })
401                .is_some_and(|r| r == "assistant");
402            if is_assistant && (msg.contains_key("tool_calls") || !msg.contains_key("content")) {
403                msg.insert("content".to_string(), Either::Left(String::new()));
404            }
405
406            result.push(msg);
407            i += 1;
408            continue;
409        }
410
411        // Collect consecutive tool messages into a single tool_responses list.
412        let mut tool_responses: Vec<IndexMap<String, serde_json::Value>> = Vec::new();
413        let mut media_parts: Vec<IndexMap<String, serde_json::Value>> = Vec::new();
414        while i < messages.len() {
415            let is_tool = messages[i]
416                .get("role")
417                .and_then(|v| match v {
418                    Either::Left(s) => Some(s.as_str()),
419                    _ => None,
420                })
421                .is_some_and(|r| r == "tool");
422            if !is_tool {
423                break;
424            }
425
426            let tool_msg = &messages[i];
427
428            let name = tool_msg
429                .get("name")
430                .and_then(|v| match v {
431                    Either::Left(s) => Some(s.clone()),
432                    _ => None,
433                })
434                .unwrap_or_else(|| "unknown".to_string());
435
436            let content = match tool_msg.get("content") {
437                Some(Either::Left(s)) => s.clone(),
438                Some(Either::Right(parts)) => {
439                    let mut text = String::new();
440                    for part in parts {
441                        match part.get("type").and_then(|v| v.as_str()) {
442                            Some("text") => {
443                                if let Some(t) = part.get("text").and_then(|v| v.as_str()) {
444                                    text.push_str(t);
445                                }
446                            }
447                            Some("image") | Some("audio") | Some("video") => {
448                                media_parts.push(part.clone());
449                            }
450                            _ => {}
451                        }
452                    }
453                    text
454                }
455                _ => String::new(),
456            };
457
458            let response_value: serde_json::Value =
459                serde_json::from_str(&content).unwrap_or(serde_json::Value::String(content));
460
461            let mut entry = IndexMap::new();
462            entry.insert("name".to_string(), serde_json::Value::String(name));
463            entry.insert("response".to_string(), response_value);
464            tool_responses.push(entry);
465
466            i += 1;
467        }
468
469        // Create a user message with the collected tool_responses.
470        let mut user_msg: IndexMap<String, MessageContent> = IndexMap::new();
471        user_msg.insert("role".to_string(), Either::Left("user".to_string()));
472        user_msg.insert("tool_responses".to_string(), Either::Right(tool_responses));
473        if !media_parts.is_empty() {
474            user_msg.insert("content".to_string(), Either::Right(media_parts));
475        }
476        result.push(user_msg);
477    }
478
479    *messages = result;
480}
481
482#[allow(clippy::too_many_arguments)]
483pub fn apply_chat_template_to(
484    mut messages: Vec<IndexMap<String, MessageContent>>,
485    add_generation_prompt: bool,
486    enable_thinking: Option<bool>,
487    reasoning_effort: Option<ReasoningEffort>,
488    template: &ChatTemplateValue,
489    bos_tok: Option<String>,
490    eos_tok: Option<String>,
491    unk_tok: Option<String>,
492    tools: Vec<Tool>,
493) -> Result<String> {
494    let mut env = Environment::new();
495
496    // enable python methods such as .strip()
497    env.set_unknown_method_callback(minijinja_contrib::pycompat::unknown_method_callback);
498
499    // https://github.com/huggingface/transformers/blob/76a33a10923ccc1074917f6b6a1e719e626b7dc9/src/transformers/tokenization_utils_base.py#L1842
500    env.set_lstrip_blocks(true);
501    env.set_trim_blocks(true);
502
503    #[derive(Serialize, Deserialize)]
504    struct UntaggedContent(#[serde(with = "either::serde_untagged")] MessageContent);
505
506    // Resolve template string early so we can check for Gemma 4 format
507    let resolved_template = match &template.0 {
508        Either::Left(x) => x.clone(),
509        Either::Right(map) => {
510            let has_tool_use = map.iter().any(|t| {
511                t.get("name").is_some_and(|name| name == "tool_use") || t.contains_key("tool_use")
512            });
513            let must_use_tool_template = !tools.is_empty();
514
515            if must_use_tool_template && !has_tool_use {
516                anyhow::bail!(
517                    "Tools were provided but this chat template does not handle tool usage"
518                );
519            }
520
521            let mut found_template = None;
522            for t in map {
523                let name = t.get("name");
524                if let Some(name) = name {
525                    found_template = Some(t["template"].clone());
526                    #[allow(clippy::if_same_then_else)]
527                    if name == "tool_use" && !tools.is_empty() {
528                        break;
529                    } else if name == "default" && !must_use_tool_template {
530                        break;
531                    }
532                } else if t.contains_key("tool_use") && !tools.is_empty() {
533                    found_template = Some(t["tool_use"].clone());
534                    break;
535                } else if t.contains_key("default") && !must_use_tool_template {
536                    found_template = Some(t["default"].clone());
537                    break;
538                }
539            }
540
541            found_template.ok_or_else(|| anyhow::anyhow!("Chat template does not contain a `tool_use` or `default` key. Please ensure it contains at least a `default` key, although `tool_use` should be specified for using tools."))?
542        }
543    };
544
545    // Pre-process messages for Gemma 4 tool templates: parse JSON-string
546    // tool_call arguments into objects (so the template renders them in
547    // <|"|> format), and merge role:"tool" messages into tool_responses on
548    // the preceding assistant message.
549    if is_gemma4_tool_template(&resolved_template) {
550        parse_gemma4_tool_call_arguments(&mut messages);
551        preprocess_gemma4_tool_messages(&mut messages);
552    }
553
554    let mut new_messages = Vec::new();
555    for message in messages {
556        let mut new_message = IndexMap::new();
557        for (k, v) in message {
558            new_message.insert(k, UntaggedContent(v));
559        }
560        new_messages.push(new_message);
561    }
562
563    // Use the already-resolved template string
564    let mut template = resolved_template.replace("[::-1]", "|reverse");
565    // Convert Python‑style descending ranges `range(..., -1, -1)` to a forward
566    // range followed by Jinja’s `|reverse` filter so it works even when
567    // negative‑step ranges aren’t supported.
568    let re = Regex::new(r"range\((?P<expr>[^,]+),\s*-1,\s*-1\)").unwrap();
569    template = re
570        .replace_all(&template, |caps: &regex::Captures| {
571            format!("range({})|reverse", &caps["expr"])
572        })
573        .into_owned();
574
575    if template.contains("{{ meta }}") {
576        // Fix for GLM4 models
577        template = template.replace("{%- set meta = message.get(\"metadata\", \"\") %}", "");
578        template = template.replace("{{ meta }}", "");
579    }
580    if template.contains("{% generation %}") && template.contains("{% endgeneration %}") {
581        // Strip for smollm3 models
582        template = template.replace("{% generation %}", "");
583        template = template.replace("{% endgeneration %}", "");
584    }
585
586    env.add_template("chat_template", &template)?;
587    env.add_function("raise_exception", raise_exception);
588    env.add_filter("tojson", tojson);
589    env.add_function("strftime_now", strftime_now);
590    let tmpl = env.get_template("chat_template").unwrap();
591
592    let date = chrono::Utc::now();
593    let date_string = date.format("%d, %B, %Y").to_string();
594
595    // Convert reasoning effort to string for template
596    let reasoning_effort_str = reasoning_effort.map(|r| r.as_str()).unwrap_or("medium");
597
598    // Detect builtin tools from the tools list
599    // Known builtin tools for GPT-OSS/Harmony format: "browser", "python"
600    // Known builtin tools for Llama 3.x: "wolfram_alpha", "web_search", "brave_search", "python", "code_interpreter"
601    let builtin_tool_names = [
602        "browser",
603        "python",
604        "code_interpreter",
605        "web_search",
606        "brave_search",
607        "wolfram_alpha",
608    ];
609    let builtin_tools: Vec<&str> = tools
610        .iter()
611        .filter_map(|t| {
612            let name = t.function.name.as_str();
613            if builtin_tool_names.contains(&name) {
614                Some(name)
615            } else {
616                None
617            }
618        })
619        .collect();
620
621    let is_gemma4 = is_gemma4_tool_template(&resolved_template);
622
623    let mut rendered = if tools.is_empty() {
624        tmpl.render(context! {
625            messages => new_messages,
626            add_generation_prompt => add_generation_prompt,
627            bos_token => bos_tok,
628            eos_token => eos_tok,
629            unk_token => unk_tok,
630            date_string => date_string,
631            enable_thinking => enable_thinking.unwrap_or(DEFAULT_ENABLE_THINKING),
632            reasoning_effort => reasoning_effort_str,
633        })?
634    } else {
635        tmpl.render(context! {
636            messages => new_messages,
637            add_generation_prompt => add_generation_prompt,
638            bos_token => bos_tok,
639            eos_token => eos_tok,
640            unk_token => unk_tok,
641            xml_tools => tools.clone(), // SmolLM3
642            tools => tools,
643            builtin_tools => builtin_tools,
644            date_string => date_string,
645            enable_thinking => enable_thinking.unwrap_or(DEFAULT_ENABLE_THINKING),
646            reasoning_effort => reasoning_effort_str,
647        })?
648    };
649
650    // Gemma 4 fix: when tool_responses are in a user turn (the correct
651    // format), the template's generation-prompt logic skips `<|turn>model\n`
652    // because it checks `prev_message_type != 'tool_response'`.  But the
653    // training data ALWAYS has `<|turn>model\n` before the model generates.
654    // Append it when the template left it out.
655    if is_gemma4 && add_generation_prompt && rendered.ends_with("<tool_response|>") {
656        rendered.push_str("<|turn>model\n");
657    }
658
659    Ok(rendered)
660}
661
662#[cfg(test)]
663mod tests {
664    use either::Either;
665    use indexmap::IndexMap;
666    use serde_json::Value;
667
668    use super::{
669        apply_chat_template_to, preprocess_gemma4_tool_messages, ChatTemplateValue,
670        GenerationConfig, DEFAULT_ENABLE_THINKING,
671    };
672    use crate::MessageContent;
673
674    fn user_text_message(text: &str) -> IndexMap<String, MessageContent> {
675        IndexMap::from([
676            ("role".to_string(), Either::Left("user".to_string())),
677            ("content".to_string(), Either::Left(text.to_string())),
678        ])
679    }
680
681    #[test]
682    fn unspecified_thinking_enables_template_thinking() {
683        let template = ChatTemplateValue(Either::Left(
684            "{% if enable_thinking is defined and enable_thinking %}<|think|>{% endif %}{{ bos_token }}{{ messages[0]['content'] }}".to_string(),
685        ));
686        let messages = vec![user_text_message("hello")];
687
688        let rendered = apply_chat_template_to(
689            messages,
690            false,
691            None,
692            None,
693            &template,
694            Some("<bos>".to_string()),
695            None,
696            None,
697            vec![],
698        )
699        .unwrap();
700        let enabled = apply_chat_template_to(
701            vec![user_text_message("hello")],
702            false,
703            Some(true),
704            None,
705            &template,
706            Some("<bos>".to_string()),
707            None,
708            None,
709            vec![],
710        )
711        .unwrap();
712
713        const { assert!(DEFAULT_ENABLE_THINKING) };
714        assert_eq!(rendered, "<|think|><bos>hello");
715        assert_eq!(rendered, enabled);
716    }
717
718    #[test]
719    fn generation_config_exposes_sampling_defaults() {
720        let config: GenerationConfig = serde_json::from_str(
721            r#"{
722                "do_sample": true,
723                "temperature": 1.0,
724                "top_k": 32,
725                "top_p": 0.9,
726                "min_p": 0.05,
727                "repetition_penalty": 1.1,
728                "max_new_tokens": 512
729            }"#,
730        )
731        .unwrap();
732
733        let defaults = config.generation_defaults().unwrap();
734        assert_eq!(defaults.do_sample, Some(true));
735        assert_eq!(defaults.temperature, Some(1.0));
736        assert_eq!(defaults.top_k, Some(32));
737        assert_eq!(defaults.top_p, Some(0.9));
738        assert_eq!(defaults.min_p, Some(0.05));
739        assert_eq!(defaults.repetition_penalty, Some(1.1));
740        assert_eq!(defaults.max_new_tokens, Some(512));
741    }
742
743    fn assistant_message_with_tool_calls() -> IndexMap<String, MessageContent> {
744        let mut tc_map = IndexMap::new();
745        tc_map.insert("id".to_string(), Value::String("call-1".to_string()));
746        tc_map.insert("type".to_string(), Value::String("function".to_string()));
747        let mut func = serde_json::Map::new();
748        func.insert("name".to_string(), Value::String("get_weather".to_string()));
749        func.insert(
750            "arguments".to_string(),
751            Value::String(r#"{"city":"Boston"}"#.to_string()),
752        );
753        tc_map.insert("function".to_string(), Value::Object(func));
754
755        IndexMap::from([
756            ("role".to_string(), Either::Left("assistant".to_string())),
757            (
758                "content".to_string(),
759                Either::Left(
760                    r#"{"name":"get_weather","arguments":"{\"city\":\"Boston\"}"}"#.to_string(),
761                ),
762            ),
763            ("tool_calls".to_string(), Either::Right(vec![tc_map])),
764        ])
765    }
766
767    fn tool_result_message(name: &str, content: &str) -> IndexMap<String, MessageContent> {
768        IndexMap::from([
769            ("role".to_string(), Either::Left("tool".to_string())),
770            ("name".to_string(), Either::Left(name.to_string())),
771            ("content".to_string(), Either::Left(content.to_string())),
772        ])
773    }
774
775    #[test]
776    fn gemma4_preprocess_creates_user_msg_for_tool_responses() {
777        let mut messages = vec![
778            user_text_message("What's the weather?"),
779            assistant_message_with_tool_calls(),
780            tool_result_message("get_weather", r#"{"temp":72}"#),
781        ];
782
783        preprocess_gemma4_tool_messages(&mut messages);
784
785        // Tool message replaced by a user message with tool_responses
786        assert_eq!(messages.len(), 3);
787        // Assistant message should NOT have tool_responses
788        assert!(!messages[1].contains_key("tool_responses"));
789        // Content should be cleared (had tool_calls)
790        let content = messages[1].get("content").unwrap();
791        assert_eq!(content, &Either::Left(String::new()));
792        // New user message should have tool_responses
793        let role = messages[2].get("role").unwrap();
794        assert_eq!(role, &Either::Left("user".to_string()));
795        assert!(messages[2].contains_key("tool_responses"));
796    }
797
/// The generated `tool_responses` entry exposes the tool `name` and a `response` whose
/// JSON content has been parsed into a structured value.
#[test]
fn gemma4_preprocess_tool_response_has_correct_structure() {
    let mut messages = vec![
        user_text_message("hi"),
        assistant_message_with_tool_calls(),
        tool_result_message("get_weather", r#"{"temp":72}"#),
    ];

    preprocess_gemma4_tool_messages(&mut messages);

    let payload = messages[2].get("tool_responses").unwrap();
    if let Either::Right(responses) = payload {
        assert_eq!(responses.len(), 1);
        assert_eq!(responses[0]["name"], "get_weather");
        // Valid JSON content is exposed as a parsed value rather than a raw string.
        assert_eq!(responses[0]["response"]["temp"], 72);
    } else {
        panic!("Expected Either::Right");
    }
}
817
/// `parse_gemma4_tool_call_arguments` must turn the JSON-encoded `arguments` string of an
/// assistant tool call into a structured object.
///
/// Both the precondition (arguments start out as a string) and the postcondition
/// (arguments end up as an object) are checked strictly: a fixture without `tool_calls`
/// fails the test instead of silently skipping the precondition.
#[test]
fn gemma4_parse_tool_call_arguments_converts_json_string_to_object() {
    let mut messages = vec![
        user_text_message("call something"),
        assistant_message_with_tool_calls(),
    ];

    // Before: arguments is a JSON string.
    match messages[1].get("tool_calls") {
        Some(Either::Right(tcs)) => {
            let func = tcs[0].get("function").unwrap();
            assert!(
                func.get("arguments").unwrap().is_string(),
                "fixture arguments should start out as a JSON string"
            );
        }
        _ => panic!("expected tool_calls"),
    }

    super::parse_gemma4_tool_call_arguments(&mut messages);

    // After: arguments should be a parsed object.
    match messages[1].get("tool_calls") {
        Some(Either::Right(tcs)) => {
            let func = tcs[0].get("function").unwrap();
            let args = func.get("arguments").unwrap();
            assert!(args.is_object(), "arguments should be parsed to object");
            assert_eq!(args.get("city").unwrap(), "Boston");
        }
        _ => panic!("expected tool_calls"),
    }
}
842
/// Two consecutive `tool` messages are merged into a single message whose
/// `tool_responses` keeps both entries in their original order.
#[test]
fn gemma4_preprocess_multiple_tool_messages() {
    let mut messages = vec![
        user_text_message("hi"),
        assistant_message_with_tool_calls(),
        tool_result_message("get_weather", r#"{"temp":72}"#),
        tool_result_message("get_forecast", "sunny"),
    ];

    preprocess_gemma4_tool_messages(&mut messages);

    // Four messages shrink to three: the two tool messages become one.
    assert_eq!(messages.len(), 3);

    let payload = messages[2].get("tool_responses").unwrap();
    if let Either::Right(responses) = payload {
        assert_eq!(responses.len(), 2);
        assert_eq!(responses[0]["name"], "get_weather");
        assert_eq!(responses[1]["name"], "get_forecast");
        // Content that is not JSON is kept as a plain string.
        assert_eq!(responses[1]["response"], "sunny");
    } else {
        panic!("Expected Either::Right");
    }
}
866
/// A conversation without any `tool` messages keeps the same number of messages after
/// preprocessing.
#[test]
fn gemma4_preprocess_no_tool_messages_is_noop() {
    let assistant_reply: IndexMap<String, MessageContent> = IndexMap::from([
        ("role".to_string(), Either::Left("assistant".to_string())),
        ("content".to_string(), Either::Left("hi there".to_string())),
    ]);
    let mut messages = vec![user_text_message("hello"), assistant_reply];
    let len_before = messages.len();

    preprocess_gemma4_tool_messages(&mut messages);

    let len_after = messages.len();
    assert_eq!(len_after, len_before);
}
882
/// A `tool` message lacking a `name` key is reported under the name `"unknown"` in the
/// generated `tool_responses`.
#[test]
fn gemma4_preprocess_tool_without_name_defaults_to_unknown() {
    // Deliberately built by hand so that the "name" key is absent.
    let nameless_tool: IndexMap<String, MessageContent> = IndexMap::from([
        ("role".to_string(), Either::Left("tool".to_string())),
        ("content".to_string(), Either::Left("result".to_string())),
    ]);
    let mut messages = vec![
        user_text_message("hi"),
        assistant_message_with_tool_calls(),
        nameless_tool,
    ];

    preprocess_gemma4_tool_messages(&mut messages);

    let payload = messages[2].get("tool_responses").unwrap();
    if let Either::Right(responses) = payload {
        assert_eq!(responses[0]["name"], "unknown");
    } else {
        panic!("Expected Either::Right");
    }
}
903
/// Fields missing from the generation config JSON stay `None` in the derived defaults,
/// while the fields that are present are carried over.
#[test]
fn generation_config_keeps_omitted_sampling_fields_unset() {
    let raw = r#"{
                "do_sample": true,
                "temperature": 1.0
            }"#;
    let config = serde_json::from_str::<GenerationConfig>(raw).unwrap();

    let defaults = config.generation_defaults().unwrap();

    // Present in the JSON: propagated as-is.
    assert_eq!(defaults.do_sample, Some(true));
    assert_eq!(defaults.temperature, Some(1.0));

    // Absent from the JSON: left unset.
    assert!(defaults.top_k.is_none());
    assert!(defaults.top_p.is_none());
    assert!(defaults.repetition_penalty.is_none());
    assert!(defaults.max_new_tokens.is_none());
    assert!(defaults.max_length.is_none());
}
923}