mlxrs 0.1.0

Safe Rust bindings for Apple's MLX array framework, with LM, VLM, audio, and embeddings support
//! Full-suite tokenizer tests (`--features lm`).
//!
//! Mirrors the existing `mlxrs/tests` style: integration tests reachable from
//! outside the crate. Covers encode/decode round-trip on an embedded fixture,
//! each streaming detokenizer, realistic HF-style chat-template rendering,
//! tool-parser parsing for json + pythonic, and `infer_tool_parser`. These
//! exercise the full feature surface and so gate on the `lm` umbrella.
#![cfg(feature = "lm")]

use std::io::Write;

use mlxrs::tokenizer::{
  StreamingDetokenizer, Tokenizer,
  stream::{BpeStreamingDetokenizer, NaiveStreamingDetokenizer, SpmStreamingDetokenizer},
  tools::{Glm47, JsonTools, Pythonic, ToolParser},
};
use serde_json::json;

/// A minimal valid HF `tokenizer.json` (WordLevel model, whitespace
/// pre-tokenizer), regenerated by `cargo xtask-codegen` from the canonical
/// `tokenizers` crate so the JSON is schema-correct by construction, then
/// committed under `tests/fixtures/`.
///
/// The text is embedded at compile time via `include_str!` (path resolved
/// relative to this source file); the tests write it back out to a temp
/// directory before loading a `Tokenizer` from that path.
const TOKENIZER_JSON: &str = include_str!("fixtures/tokenizer.json");

/// The matching `tokenizer_config.json`, regenerated by `cargo xtask-codegen`
/// from typed `serde_json` (no hand-authored JSON blob) and committed under
/// `tests/fixtures/`.
///
/// Embedded at compile time via `include_str!` (path resolved relative to
/// this source file) and written next to `tokenizer.json` by `fixture_dir`.
const TOKENIZER_CONFIG_JSON: &str = include_str!("fixtures/tokenizer_config.json");

/// Returns the directory holding the on-disk tokenizer fixture, writing it
/// the first time any test asks for it.
///
/// `cargo test` runs the tests of this binary in parallel. If every caller
/// re-created (and thereby truncated) the files, a concurrent
/// `Tokenizer::from_path` could observe a half-written file. The fixture
/// content never changes, so the write is funnelled through a `OnceLock`:
/// exactly one caller populates the directory, everyone else just gets the
/// already-complete path back, and the tests stay un-serialized.
///
/// The directory name embeds the process id; each call returns an owned copy
/// of the cached path.
fn fixture_dir() -> std::path::PathBuf {
  static WRITTEN: std::sync::OnceLock<std::path::PathBuf> = std::sync::OnceLock::new();

  // Runs at most once per process (on the first successful initialization).
  let populate = || {
    let root = std::env::temp_dir().join(format!("mlxrs-tok-fixture-{}", std::process::id()));
    std::fs::create_dir_all(&root).unwrap();
    for (file_name, contents) in [
      ("tokenizer.json", TOKENIZER_JSON),
      ("tokenizer_config.json", TOKENIZER_CONFIG_JSON),
    ] {
      let mut file = std::fs::File::create(root.join(file_name)).unwrap();
      file.write_all(contents.as_bytes()).unwrap();
    }
    root
  };

  WRITTEN.get_or_init(populate).to_path_buf()
}

/// Encodes a fixed sentence, pins the exact ids, and checks that decoding
/// followed by re-encoding reproduces those ids.
#[test]
fn encode_decode_round_trip() {
  const SENTENCE: &str = "hello world the quick brown fox";

  let tokenizer = Tokenizer::from_path(fixture_dir(), None).unwrap();

  let encoded = tokenizer.encode(SENTENCE, false).unwrap();
  assert_eq!(encoded, vec![3, 4, 5, 6, 7, 8]);

  let decoded = tokenizer.decode(&encoded, false).unwrap();
  assert!(decoded.contains("hello"));
  assert!(decoded.contains("fox"));

  // Feeding the decoded text back through `encode` must land on the same ids.
  let reencoded = tokenizer.encode(&decoded, false).unwrap();
  assert_eq!(reencoded, encoded);
}

/// Checks the special-token accessors (strings, ids, chat-template flag and
/// thinking markers) against the values the shared fixture is expected to
/// expose.
#[test]
fn special_token_properties() {
  let tokenizer = Tokenizer::from_path(fixture_dir(), None).unwrap();

  // Special-token strings.
  assert_eq!(tokenizer.bos_token(), Some("<s>"));
  assert_eq!(tokenizer.eos_token(), Some("</s>"));
  assert_eq!(tokenizer.unk_token(), Some("<unk>"));

  // Their ids.
  assert_eq!(tokenizer.bos_token_id(), Some(1));
  assert_eq!(tokenizer.eos_token_id(), Some(2));
  assert!(tokenizer.contains_eos_id(2));

  assert!(tokenizer.has_chat_template());

  // _infer_thinking: <think>/</think> are in vocab.
  let think_start_ids: &[u32] = &[9];
  assert!(tokenizer.has_thinking());
  assert_eq!(tokenizer.think_start(), Some("<think>"));
  assert_eq!(tokenizer.think_start_tokens(), Some(think_start_ids));
}

/// Streams the encoded ids one at a time through the tokenizer's own
/// detokenizer and checks that the accumulated text and token list match a
/// one-shot `decode` of the same ids.
#[test]
fn naive_detokenizer_reconstructs_full_decode() {
  let tokenizer = Tokenizer::from_path(fixture_dir(), None).unwrap();
  let token_ids = tokenizer.encode("hello world the quick", false).unwrap();
  let expected_text = tokenizer.decode(&token_ids, false).unwrap();

  let mut stream = tokenizer.detokenizer();
  stream.reset();
  token_ids.iter().copied().for_each(|id| {
    stream.add_token(id);
  });
  stream.finalize();

  assert_eq!(stream.text(), expected_text);
  assert_eq!(stream.tokens(), &token_ids[..]);
}

/// Feeds three SentencePiece-style pieces through the SPM streaming
/// detokenizer and checks the final text.
#[test]
fn spm_detokenizer_incremental_matches_join() {
  // SPM marks word starts with the ▁ (U+2581) separator; this is a tiny
  // piece -> id table covering just the tokens streamed below.
  let vocab: Vec<(String, u32)> = [("\u{2581}Hello", 0u32), ("\u{2581}world", 1), ("!", 2)]
    .iter()
    .map(|&(piece, id)| (String::from(piece), id))
    .collect();

  let mut stream = SpmStreamingDetokenizer::new(vocab, true);
  stream.reset();
  for id in 0u32..3 {
    stream.add_token(id);
  }
  stream.finalize();

  assert_eq!(stream.text(), "Hello world!");
}

/// Feeds two byte-level BPE pieces through the BPE streaming detokenizer and
/// checks the final text.
#[test]
fn bpe_detokenizer_incremental_matches_join() {
  // GPT-2 byte-level: 'Ġ' (U+0120) maps to a space.
  let vocab: Vec<(String, u32)> = [("Hello", 0u32), ("\u{0120}world", 1)]
    .iter()
    .map(|&(piece, id)| (String::from(piece), id))
    .collect();

  let mut stream = BpeStreamingDetokenizer::new(vocab, false);
  stream.reset();
  for id in 0u32..2 {
    stream.add_token(id);
  }
  stream.finalize();

  assert_eq!(stream.text(), "Hello world");
}

/// Drives `NaiveStreamingDetokenizer` with a stand-in decode function (no
/// real tokenizer involved) and checks the streamed text equals decoding all
/// ids at once.
#[test]
fn naive_detokenizer_standalone_matches() {
  // Stand-in decoder: renders each id as `t<id>` followed by a space.
  fn decode(ids: &[u32]) -> String {
    let mut rendered = String::new();
    for id in ids {
      rendered.push_str(&format!("t{id} "));
    }
    rendered
  }

  let expected = decode(&[1, 2, 3]);

  let mut stream = NaiveStreamingDetokenizer::new(decode, false);
  stream.reset();
  for id in 1u32..=3 {
    stream.add_token(id);
  }
  stream.finalize();

  assert_eq!(stream.text(), expected);
}

/// Renders the fixture's chat template twice — once with a tool list and a
/// generation prompt, once with neither — and pins the exact output strings.
#[test]
fn chat_template_render_with_generation_prompt_and_tool() {
  let tokenizer = Tokenizer::from_path(fixture_dir(), None).unwrap();

  let conversation = json!([
    {"role": "user", "content": "hi"},
    {"role": "assistant", "content": "yo"}
  ]);
  let tool_specs = json!([
    {"type": "function", "function": {"name": "get_weather", "parameters": {}}}
  ]);

  // Tools supplied, generation prompt requested.
  let with_prompt = tokenizer
    .apply_chat_template(&conversation, Some(&tool_specs), true, false, None)
    .unwrap();
  assert_eq!(
    with_prompt,
    "<s><|user|>hi<|assistant|>yo<tool>get_weather</tool><|assistant|>"
  );

  // Without generation prompt, the trailing assistant marker is absent.
  let without_prompt = tokenizer
    .apply_chat_template(&conversation, None, false, false, None)
    .unwrap();
  assert_eq!(without_prompt, "<s><|user|>hi<|assistant|>yo");
}

/// Parses a single JSON-object tool call with `JsonTools` and checks the
/// extracted name and arguments.
#[test]
fn json_tools_parser_parses_assistant_tool_call() {
  let payload = r#"{"name": "get_weather", "arguments": {"city": "Paris"}}"#;

  let calls = JsonTools.parse(payload, None).unwrap();
  assert_eq!(calls.len(), 1);

  let call = &calls[0];
  assert_eq!(call.name(), "get_weather");
  assert_eq!(*call.arguments(), json!({"city": "Paris"}));
}

/// Parses a bracketed, Python-call-style tool invocation with `Pythonic` and
/// checks the name plus both keyword arguments (one string, one integer).
#[test]
fn pythonic_parser_parses_assistant_tool_call() {
  let payload = r#"<|tool_call_start|>[get_weather(city="Paris", days=3)]<|tool_call_end|>"#;

  let calls = Pythonic.parse(payload, None).unwrap();
  assert_eq!(calls.len(), 1);

  let call = &calls[0];
  assert_eq!(call.name(), "get_weather");

  let args = call.arguments();
  assert_eq!(args["city"], json!("Paris"));
  assert_eq!(args["days"], json!(3));
}

/// Parses a GLM-4.7 style tool call (function name followed by
/// `<arg_key>`/`<arg_value>` pairs) with `Glm47` and checks the extracted
/// name and argument.
#[test]
fn glm47_parser_parses_xml_style() {
  let p = Glm47;
  let calls = p
    .parse(
      "get_weather<arg_key>city</arg_key><arg_value>Paris</arg_value>",
      None,
    )
    .unwrap();
  // Pin the call count like the JsonTools/Pythonic tests do: an empty result
  // then fails with a readable assertion instead of an index-out-of-bounds
  // panic, and a parser that emits extra calls is caught too.
  assert_eq!(calls.len(), 1);
  assert_eq!(calls[0].name(), "get_weather");
  assert_eq!(calls[0].arguments()["city"], json!("Paris"));
}

/// Direct unit test for the [`Tokenizer::additional_special_token_ids`]
/// accessor — used by `lm::gguf::HfVocab` to union `tokenizer_config.json`
/// `additional_special_tokens` ids into the special-token set, mirroring
/// HF's `PreTrainedTokenizerBase.additional_special_tokens_ids`.
///
/// The fixture reuses the shared `tokenizer.json` (so `<think>` at id 9 and
/// `</think>` at id 10 are present), but plants a one-off
/// `tokenizer_config.json` in a fresh dir that declares both as
/// `additional_special_tokens` — one as a plain string and one as the
/// `AddedToken` object shape (`{"content": ...}`) — to cover both JSON
/// shapes the accessor accepts.
///
/// The scratch directory is removed by a drop guard, so it is cleaned up
/// even when one of the `unwrap`s or the final assertion panics.
#[test]
fn additional_special_token_ids_unions_string_and_object_forms() {
  /// Best-effort removal of the scratch directory when the test ends,
  /// whether it returns normally or unwinds from a panic.
  struct ScratchDir(std::path::PathBuf);

  impl Drop for ScratchDir {
    fn drop(&mut self) {
      // Cleanup failure must not mask the real test outcome.
      let _ = std::fs::remove_dir_all(&self.0);
    }
  }

  // Process id + wall-clock nanos keep the directory distinct from the
  // shared fixture and from other runs of this test.
  let scratch = ScratchDir(std::env::temp_dir().join(format!(
    "mlxrs-tok-addl-{}-{}",
    std::process::id(),
    std::time::SystemTime::now()
      .duration_since(std::time::UNIX_EPOCH)
      .map(|d| d.as_nanos())
      .unwrap_or(0)
  )));
  let dir = &scratch.0;

  std::fs::create_dir_all(dir).unwrap();
  std::fs::write(dir.join("tokenizer.json"), TOKENIZER_JSON).unwrap();
  let cfg = json!({
    "bos_token": "<s>",
    "eos_token": "</s>",
    "unk_token": "<unk>",
    "additional_special_tokens": [
      "<think>",
      { "content": "</think>", "lstrip": false, "rstrip": false, "single_word": false, "normalized": false, "special": true },
      "not-in-vocab"  // silently skipped — `token_to_id` returns None
    ]
  });
  std::fs::write(
    dir.join("tokenizer_config.json"),
    serde_json::to_string(&cfg).unwrap(),
  )
  .unwrap();

  // Declared after `scratch`, so the tokenizer is dropped before the
  // directory is deleted.
  let tok = Tokenizer::from_path(dir, None).unwrap();
  let mut ids = tok.additional_special_token_ids();
  // Sorted before comparing so the check does not depend on return order.
  ids.sort_unstable();
  assert_eq!(ids, vec![9u32, 10u32]);
}

/// Table-driven check of `infer_tool_parser`: each chat-template snippet is
/// paired with the parser name it should select, including the "no marker"
/// and "no template" cases that select nothing.
#[test]
fn infer_tool_parser_selects_correctly() {
  use mlxrs::tokenizer::infer_tool_parser;

  // (chat-template text, expected parser name)
  let cases: [(Option<&str>, Option<&str>); 6] = [
    (Some("... [TOOL_CALLS] ..."), Some("mistral")),
    (Some("uses <minimax:tool_call> here"), Some("minimax_m2")),
    (
      Some("calls <tool_call> with tool_call.name field"),
      Some("json_tools"),
    ),
    (Some("<|tool_list_start|>"), Some("pythonic")),
    (Some("no markers"), None),
    (None, None),
  ];

  for (template, expected) in cases {
    assert_eq!(infer_tool_parser(template), expected);
  }
}