use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::time::{Duration, Instant};
use crate::script_data::{List, ScriptData, get_normalized_script_name};
use crate::transliterate::transliterate::{
TransliterationFnOptions, resolve_transliteration_rules, transliterate_text_core,
};
/// Default for [`TypingContextOptions::auto_context_clear_time_ms`], in milliseconds.
pub const DEFAULT_AUTO_CONTEXT_CLEAR_TIME_MS: u64 = 4500;
/// Default for [`TypingContextOptions::use_native_numerals`].
pub const DEFAULT_USE_NATIVE_NUMERALS: bool = true;
/// Default for [`TypingContextOptions::include_inherent_vowel`].
pub const DEFAULT_INCLUDE_INHERENT_VOWEL: bool = false;
/// Options accepted by [`TypingContext::new`] and [`emulate_typing`].
#[derive(Debug, Clone)]
pub struct TypingContextOptions {
/// Maximum gap between two key presses, in milliseconds. When the gap is
/// exceeded, [`TypingContext::take_key_input`] clears its context before
/// handling the new key.
pub auto_context_clear_time_ms: u64,
/// Forwarded unchanged as `use_native_numerals` to the transliteration options.
pub use_native_numerals: bool,
/// Forwarded unchanged as `include_inherent_vowel` to the transliteration options.
pub include_inherent_vowel: bool,
}
impl Default for TypingContextOptions {
fn default() -> Self {
Self {
auto_context_clear_time_ms: DEFAULT_AUTO_CONTEXT_CLEAR_TIME_MS,
use_native_numerals: DEFAULT_USE_NATIVE_NUMERALS,
include_inherent_vowel: DEFAULT_INCLUDE_INHERENT_VOWEL,
}
}
}
/// Edit produced by a single [`TypingContext::take_key_input`] call.
///
/// A consumer applies it the way [`emulate_typing`] does: remove
/// `to_delete_chars_count` trailing characters, then append `diff_add_text`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TypingDiff {
/// Number of trailing `char`s of the previous output that must be removed.
pub to_delete_chars_count: usize,
/// Text to append after the deletion.
pub diff_add_text: String,
/// `context_length` reported by the transliteration core for this key press;
/// a value of `0` means the typing context was cleared.
pub context_length: usize,
}
/// Stateful typing session that turns `Normal` key presses into text of a target script.
///
/// Keys are fed one at a time through [`TypingContext::take_key_input`], which
/// re-transliterates the accumulated input and reports the change as a [`TypingDiff`].
#[derive(Debug)]
pub struct TypingContext {
// Normalized name of the target script, as returned by `get_normalized_script_name`.
normalized_typing_lang: String,
// Option flags forwarded to the transliteration core on every key press.
use_native_numerals: bool,
include_inherent_vowel: bool,
// Keys accumulated since the context was last cleared.
curr_input: String,
// Output produced for `curr_input` on the previous key press.
curr_output: String,
// Idle time after which the context is cleared before handling a new key.
auto_context_clear_time: Duration,
// When the previous key press finished; `None` right after a clear.
last_time: Option<Instant>,
// Script data for the source ("Normal") and target scripts.
from_script_data: &'static ScriptData,
to_script_data: &'static ScriptData,
// Rules resolved once in `new` via `resolve_transliteration_rules`.
trans_options: HashMap<String, bool>,
custom_rules: Vec<crate::script_data::Rule>,
}
impl TypingContext {
pub fn new(typing_lang: &str, options: Option<TypingContextOptions>) -> Result<Self, String> {
let opts = options.unwrap_or_default();
let normalized_typing_lang = get_normalized_script_name(typing_lang)
.ok_or_else(|| format!("Invalid script name: {}", typing_lang))?;
let from_script_data = ScriptData::get_script_data("Normal");
let to_script_data = ScriptData::get_script_data(&normalized_typing_lang);
let resolved = resolve_transliteration_rules(from_script_data, to_script_data, None);
Ok(Self {
normalized_typing_lang,
use_native_numerals: opts.use_native_numerals,
include_inherent_vowel: opts.include_inherent_vowel,
curr_input: String::new(),
curr_output: String::new(),
auto_context_clear_time: Duration::from_millis(opts.auto_context_clear_time_ms),
last_time: None,
from_script_data,
to_script_data,
trans_options: resolved.trans_options,
custom_rules: resolved.custom_rules,
})
}
pub fn clear_context(&mut self) {
self.last_time = None;
self.curr_input.clear();
self.curr_output.clear();
}
fn build_translit_options(&self) -> TransliterationFnOptions {
TransliterationFnOptions {
typing_mode: true,
use_native_numerals: self.use_native_numerals,
include_inherent_vowel: self.include_inherent_vowel,
}
}
pub fn take_key_input(&mut self, key: &str) -> Result<TypingDiff, String> {
let Some(ch) = key.chars().next() else {
return Ok(TypingDiff {
to_delete_chars_count: 0,
diff_add_text: String::new(),
context_length: 0,
});
};
let now = Instant::now();
if let Some(last) = self.last_time {
if now.duration_since(last) > self.auto_context_clear_time {
self.clear_context();
}
}
self.curr_input.push(ch);
let prev_output = self.curr_output.clone();
let result = transliterate_text_core(
self.curr_input.clone(),
"Normal",
&self.normalized_typing_lang,
self.from_script_data,
self.to_script_data,
&self.trans_options,
&self.custom_rules,
Some(self.build_translit_options()),
)?;
let context_length = result.context_length;
let output = result.output;
if context_length > 0 {
self.curr_output = output.clone();
} else if context_length == 0 {
self.clear_context();
}
let (to_delete_chars_count, diff_add_text) = compute_diff(&prev_output, &output);
self.last_time = Some(Instant::now());
Ok(TypingDiff {
to_delete_chars_count,
diff_add_text,
context_length,
})
}
pub fn update_use_native_numerals(&mut self, use_native_numerals: bool) {
self.use_native_numerals = use_native_numerals;
}
pub fn update_include_inherent_vowel(&mut self, include_inherent_vowel: bool) {
self.include_inherent_vowel = include_inherent_vowel;
}
pub fn get_use_native_numerals(&self) -> bool {
self.use_native_numerals
}
pub fn get_include_inherent_vowel(&self) -> bool {
self.include_inherent_vowel
}
pub fn get_normalized_script(&self) -> String {
self.normalized_typing_lang.clone()
}
}
/// Computes the edit that turns `prev_output` into `output`.
///
/// Returns `(to_delete, to_add)`: the number of trailing `char`s of
/// `prev_output` after the longest common prefix, and the part of `output`
/// after that prefix. Comparison is done per `char`, not per byte.
fn compute_diff(prev_output: &str, output: &str) -> (usize, String) {
    let shared_prefix_len = prev_output
        .chars()
        .zip(output.chars())
        .take_while(|(old, new)| old == new)
        .count();
    let stale_tail_len = prev_output.chars().skip(shared_prefix_len).count();
    let appended: String = output.chars().skip(shared_prefix_len).collect();
    (stale_tail_len, appended)
}
pub fn emulate_typing(
text: &str,
typing_lang: &str,
options: Option<TypingContextOptions>,
) -> Result<String, String> {
let mut ctx = TypingContext::new(typing_lang, options)?;
let mut result = String::new();
for ch in text.chars() {
let diff = ctx.take_key_input(&ch.to_string())?;
if diff.to_delete_chars_count > 0 {
truncate_last_chars(&mut result, diff.to_delete_chars_count);
}
result.push_str(&diff.diff_add_text);
}
Ok(result)
}
/// Removes the last `n` `char`s from `s`, or empties `s` when it holds fewer than `n`.
fn truncate_last_chars(s: &mut String, n: usize) {
    if n == 0 {
        return;
    }
    // Byte offset of the n-th char from the end; 0 when the string is shorter than that.
    let cut_at = s
        .char_indices()
        .rev()
        .nth(n - 1)
        .map_or(0, |(byte_idx, _)| byte_idx);
    s.truncate(cut_at);
}
/// Category of a krama entry, mirroring the variants of the script-data `List` enum.
///
/// Serialized by serde with `snake_case` variant names.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ListType {
/// Counterpart of `List::Anya`; also the fallback used in this module when
/// an entry has no associated list.
Anya,
/// Counterpart of `List::Vyanjana` (presumably consonants — confirm against the script data).
Vyanjana,
/// Counterpart of `List::Matra` (presumably dependent vowel signs — confirm).
Matra,
/// Counterpart of `List::Svara` (presumably independent vowels — confirm).
Svara,
}
impl ListType {
    /// Maps a script-data `List` entry to the `ListType` variant of the same name.
    fn from_list(list: &List) -> Self {
        match *list {
            List::Svara { .. } => Self::Svara,
            List::Matra { .. } => Self::Matra,
            List::Vyanjana { .. } => Self::Vyanjana,
            List::Anya { .. } => Self::Anya,
        }
    }
}
/// One row of a typing data map: `(text, list_type, mappings)`, where
/// `mappings` are the typing-input strings collected for `text` by
/// [`get_script_typing_data_map`].
pub type TypingDataMapItem = (String, ListType, Vec<String>);
/// Result of [`get_script_typing_data_map`].
///
/// Both lists only contain rows that have at least one typing mapping.
#[derive(Debug)]
pub struct ScriptTypingDataMap {
/// Rows built from the script's `krama_text_arr`.
pub common_krama_map: Vec<TypingDataMapItem>,
/// Rows built from the script's `custom_script_chars_arr`.
pub script_specific_krama_map: Vec<TypingDataMapItem>,
}
pub fn get_script_typing_data_map(script: &str) -> Result<ScriptTypingDataMap, String> {
let normalized_typing_lang =
get_normalized_script_name(script).ok_or_else(|| format!("Invalid script name: {}", script))?;
if normalized_typing_lang == "Normal" {
return Err(format!("Invalid script name: {}", script));
}
let script_data = ScriptData::get_script_data(&normalized_typing_lang);
let common_attr = script_data.get_common_attr();
fn merge_duplicate_text_mappings(items: Vec<TypingDataMapItem>) -> Vec<TypingDataMapItem> {
use std::collections::{HashMap, HashSet};
let mut key_to_index: HashMap<(String, ListType), usize> = HashMap::new();
let mut mapping_sets: Vec<HashSet<String>> = Vec::new();
let mut out: Vec<TypingDataMapItem> = Vec::new();
for (text, list_type, mappings) in items {
let key = (text.clone(), list_type.clone());
if let Some(&existing_index) = key_to_index.get(&key) {
let set = &mut mapping_sets[existing_index];
let target_mappings = &mut out[existing_index].2;
for m in mappings {
if set.insert(m.clone()) {
target_mappings.push(m);
}
}
} else {
let mut set = HashSet::new();
let mut uniq = Vec::new();
for m in mappings {
if set.insert(m.clone()) {
uniq.push(m);
}
}
key_to_index.insert(key, out.len());
out.push((text, list_type, uniq));
mapping_sets.push(set);
}
}
out
.into_iter()
.filter(|(_, _, mappings)| !mappings.is_empty())
.collect()
}
let mut common_krama_map: Vec<TypingDataMapItem> = common_attr
.krama_text_arr
.iter()
.map(|(text, list_index)| {
let list_type = list_index
.and_then(|idx| common_attr.list.get(idx as usize))
.map(ListType::from_list)
.unwrap_or(ListType::Anya);
(text.clone(), list_type, Vec::new())
})
.collect();
let mut script_specific_krama_map: Vec<TypingDataMapItem> = common_attr
.custom_script_chars_arr
.iter()
.map(|(text, list_index, _)| {
let list_type = list_index
.and_then(|idx| common_attr.list.get(idx as usize))
.map(ListType::from_list)
.unwrap_or(ListType::Anya);
(text.clone(), list_type, Vec::new())
})
.collect();
for (normal_text_map, item) in &common_attr.typing_text_to_krama_map {
if normal_text_map.is_empty() {
continue;
}
if let Some(custom_back_ref) = item.custom_back_ref {
if let Some(entry) = script_specific_krama_map.get_mut(custom_back_ref as usize) {
entry.2.push(normal_text_map.clone());
}
} else if let Some(ref krama) = item.krama {
if krama.len() == 1 {
let krama_index = krama[0];
if krama_index >= 0 {
if let Some(entry) = common_krama_map.get_mut(krama_index as usize) {
entry.2.push(normal_text_map.clone());
}
}
}
}
}
common_krama_map = merge_duplicate_text_mappings(common_krama_map);
script_specific_krama_map = merge_duplicate_text_mappings(script_specific_krama_map);
Ok(ScriptTypingDataMap {
common_krama_map,
script_specific_krama_map,
})
}
/// One krama entry as returned by [`get_script_krama_data`]: `(text, list_type)`.
pub type KramaDataItem = (String, ListType);
/// Returns every entry of the script's `krama_text_arr` together with its [`ListType`].
///
/// Entries without a resolvable list get [`ListType::Anya`].
///
/// # Errors
///
/// Returns `Invalid script name: <script>` when `script` cannot be normalized
/// or normalizes to `Normal`.
pub fn get_script_krama_data(script: &str) -> Result<Vec<KramaDataItem>, String> {
    let Some(normalized) = get_normalized_script_name(script) else {
        return Err(format!("Invalid script name: {}", script));
    };
    if normalized == "Normal" {
        return Err(format!("Invalid script name: {}", script));
    }

    let script_data = ScriptData::get_script_data(&normalized);
    let common_attr = script_data.get_common_attr();

    let mut krama_data: Vec<KramaDataItem> = Vec::new();
    for (text, list_idx) in common_attr.krama_text_arr.iter() {
        let list_type = match list_idx.and_then(|idx| common_attr.list.get(idx as usize)) {
            Some(list) => ListType::from_list(list),
            None => ListType::Anya,
        };
        krama_data.push((text.clone(), list_type));
    }
    Ok(krama_data)
}
#[cfg(test)]
mod tests {
use super::*;
use crate::transliterate::helpers::VEDIC_SVARAS;
use serde::Deserialize;
use std::fs;
use std::fs::OpenOptions;
use std::io::Write;
use std::path::{Path, PathBuf};
fn de_index<'de, D>(deserializer: D) -> Result<String, D::Error>
where
D: serde::Deserializer<'de>,
{
struct IndexVisitor;
impl serde::de::Visitor<'_> for IndexVisitor {
type Value = String;
fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
formatter.write_str("a yaml index (number or string)")
}
fn visit_i64<E>(self, v: i64) -> Result<Self::Value, E>
where
E: serde::de::Error,
{
Ok(v.to_string())
}
fn visit_u64<E>(self, v: u64) -> Result<Self::Value, E>
where
E: serde::de::Error,
{
Ok(v.to_string())
}
fn visit_f64<E>(self, v: f64) -> Result<Self::Value, E>
where
E: serde::de::Error,
{
Ok(v.to_string())
}
fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
where
E: serde::de::Error,
{
Ok(v.to_string())
}
fn visit_string<E>(self, v: String) -> Result<Self::Value, E>
where
E: serde::de::Error,
{
Ok(v)
}
}
deserializer.deserialize_any(IndexVisitor)
}
/// One case of a transliteration YAML fixture, as consumed by
/// `emulate_typing_auto_transliteration_yaml`.
#[derive(Debug, Deserialize)]
struct TransliterationTestCase {
// Case identifier; accepted as number or string via `de_index`. Not read by the tests.
#[serde(deserialize_with = "de_index")]
#[allow(dead_code)]
index: String,
// Source and target script names; only `Normal` -> non-`Normal` cases are emulated.
from: String,
to: String,
// Text typed key by key, and the output expected from `emulate_typing`.
input: String,
output: String,
// Present in the fixtures but not read by the typing tests.
#[serde(default)]
#[allow(dead_code)]
options: Option<std::collections::HashMap<String, bool>>,
// Present in the fixtures but not read by the typing tests.
#[serde(default)]
#[allow(dead_code)]
reversible: Option<bool>,
// Cases marked `todo: true` are skipped.
#[serde(default)]
todo: Option<bool>,
}
/// Path of the transliteration fixtures: `<manifest dir>/../../test_data/transliteration`.
fn transliteration_test_data_root() -> PathBuf {
    let mut root = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
    root.extend(["..", "..", "test_data", "transliteration"]);
    root
}
/// Replays every `Normal` -> script case of the auto-transliteration fixtures
/// through `emulate_typing` and compares the result with the expected output.
///
/// `Tamil-Extended` results containing a Vedic svara are counted but not asserted.
/// A one-line summary is printed and written, best effort, to `test_log/`.
#[test]
fn emulate_typing_auto_transliteration_yaml() {
    use serde_yaml_ng as yaml;

    let root = transliteration_test_data_root();
    let mut total_emulations: usize = 0;
    let mut auto_vedic_skipped: usize = 0;

    for folder in ["auto-nor-brahmic", "auto-nor-other"].map(|name| root.join(name)) {
        let entries = fs::read_dir(&folder)
            .unwrap_or_else(|e| panic!("Failed listing YAML files in `{}`: {e}", folder.display()));
        for entry in entries {
            let path = entry.expect("Failed to read directory entry").path();
            if !matches!(path.extension().and_then(|ext| ext.to_str()), Some("yaml")) {
                continue;
            }
            let yaml_text = fs::read_to_string(&path)
                .unwrap_or_else(|e| panic!("Failed reading YAML file `{}`: {e}", path.display()));
            let cases: Vec<TransliterationTestCase> = yaml::from_str(&yaml_text)
                .unwrap_or_else(|e| panic!("Failed parsing YAML file `{}`: {e}", path.display()));
            let file_name = path
                .file_name()
                .and_then(|name| name.to_str())
                .unwrap_or("<unknown>");

            // Only non-todo cases typed from `Normal` into another script are emulated.
            let runnable = cases.into_iter().filter(|case| {
                !case.todo.unwrap_or(false) && case.from == "Normal" && case.to != "Normal"
            });
            for case in runnable {
                total_emulations += 1;
                let result = emulate_typing(&case.input, &case.to, None)
                    .unwrap_or_else(|e| panic!("emulate_typing error for {}: {}", path.display(), e));
                if file_name.starts_with("auto")
                    && case.to == "Tamil-Extended"
                    && VEDIC_SVARAS.iter().any(|sv| result.contains(*sv))
                {
                    auto_vedic_skipped += 1;
                    continue;
                }
                assert_eq!(
                    result,
                    case.output,
                    "Emulate Typing failed:\n From: {}\n To: {}\n Input: \"{}\"\n Expected: \"{}\"\n Actual: \"{}\"",
                    case.from,
                    case.to,
                    case.input,
                    case.output,
                    result
                );
            }
        }
    }

    let passed = total_emulations.saturating_sub(auto_vedic_skipped);
    let summary = format!(
        "Emulate Typing (auto transliteration): total_emulations={}, auto_vedic_skipped={}, passed={}",
        total_emulations, auto_vedic_skipped, passed
    );
    println!("{}", summary);

    // Logging is best effort: failures to create or write the log are ignored.
    let _ = std::fs::create_dir_all("test_log");
    let log_file = OpenOptions::new()
        .create(true)
        .write(true)
        .truncate(true)
        .open("test_log/typing_auto_emulate_log.txt");
    if let Ok(mut file) = log_file {
        let _ = writeln!(file, "{}", summary);
    }
}
/// Optional per-case overrides read from the typing YAML fixtures; converted
/// to `TypingContextOptions` by `build_typing_options`.
#[derive(Debug, Deserialize, Default)]
struct TypingOptionsYaml {
#[serde(rename = "useNativeNumerals")]
#[serde(default)]
use_native_numerals: Option<bool>,
#[serde(rename = "includeInherentVowel")]
#[serde(default)]
include_inherent_vowel: Option<bool>,
// NOTE(review): the key is spelled `autoContextTClearTimeMs` (extra `T`);
// presumably this matches the key used by the fixtures — confirm before renaming.
#[serde(rename = "autoContextTClearTimeMs")]
#[serde(default)]
auto_context_clear_time_ms: Option<u64>,
}
/// One case of a typing YAML fixture, as consumed by `typing_mode_yaml_tests`.
#[derive(Debug, Deserialize)]
struct TypingTestCase {
// Case identifier, used only in failure messages.
index: i64,
// Text typed key by key through `emulate_typing`.
text: String,
// Expected result of the emulation.
output: String,
// Target script name.
script: String,
// When true, the result is transliterated back to `Normal` and must equal `text`.
#[serde(default)]
preserve_check: bool,
// Cases marked `todo: true` are skipped.
#[serde(default)]
todo: bool,
// Optional overrides for the typing context options.
#[serde(default)]
options: Option<TypingOptionsYaml>,
}
/// Path of the typing fixtures: `<manifest dir>/../../test_data/typing`.
fn typing_test_data_root() -> PathBuf {
    let mut root = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
    root.extend(["..", "..", "test_data", "typing"]);
    root
}
/// Recursively appends every `*.yaml` file below `dir` to `out`.
///
/// Directories named `context` are not descended into. Any I/O error aborts
/// the walk and is returned; paths collected so far remain in `out`.
fn list_yaml_files_typing(dir: &Path, out: &mut Vec<PathBuf>) -> std::io::Result<()> {
    for entry in fs::read_dir(dir)? {
        let path = entry?.path();
        if !path.is_dir() {
            if matches!(path.extension().and_then(|ext| ext.to_str()), Some("yaml")) {
                out.push(path);
            }
            continue;
        }
        let is_context_dir = matches!(
            path.file_name().and_then(|name| name.to_str()),
            Some("context")
        );
        if !is_context_dir {
            list_yaml_files_typing(&path, out)?;
        }
    }
    Ok(())
}
fn build_typing_options(opts: &Option<TypingOptionsYaml>) -> Option<TypingContextOptions> {
let some_opts = match opts {
None => return None,
Some(o) => o,
};
let mut rust_opts = TypingContextOptions::default();
if let Some(v) = some_opts.use_native_numerals {
rust_opts.use_native_numerals = v;
}
if let Some(v) = some_opts.include_inherent_vowel {
rust_opts.include_inherent_vowel = v;
}
if let Some(v) = some_opts.auto_context_clear_time_ms {
rust_opts.auto_context_clear_time_ms = v;
}
Some(rust_opts)
}
/// Runs every non-todo case of the typing YAML fixtures through `emulate_typing`.
///
/// Cases with `preserve_check` are additionally transliterated back to
/// `Normal` with `all_to_normal:preserve_specific_chars` enabled and must
/// reproduce the typed text. A one-line summary is printed and written, best
/// effort, to `test_log/`.
#[test]
fn typing_mode_yaml_tests() {
    use serde_yaml_ng as yaml;

    let root = typing_test_data_root();
    let mut files: Vec<PathBuf> = Vec::new();
    if let Err(e) = list_yaml_files_typing(&root, &mut files) {
        panic!("Failed listing YAML files in `{}`: {e}", root.display());
    }
    files.sort();
    assert!(
        !files.is_empty(),
        "No YAML typing test files found in `{}`",
        root.display()
    );

    // The preserve check always uses the same option set, so it is built once.
    let preserve_options = std::collections::HashMap::from([(
        "all_to_normal:preserve_specific_chars".to_string(),
        true,
    )]);

    let mut total_emulations: usize = 0;
    let mut preserve_checks: usize = 0;
    for file in &files {
        let yaml_text = fs::read_to_string(file)
            .unwrap_or_else(|e| panic!("Failed reading YAML file `{}`: {e}", file.display()));
        let cases: Vec<TypingTestCase> = yaml::from_str(&yaml_text)
            .unwrap_or_else(|e| panic!("Failed parsing YAML file `{}`: {e}", file.display()));

        for case in cases.into_iter().filter(|case| !case.todo) {
            total_emulations += 1;
            let result =
                emulate_typing(&case.text, &case.script, build_typing_options(&case.options))
                    .unwrap_or_else(|e| {
                        panic!(
                            "emulate_typing error in `{}` index {}: {}",
                            file.display(),
                            case.index,
                            e
                        )
                    });
            assert_eq!(
                result,
                case.output,
                "Typing Mode failed in `{}` index {} (script {}): input {:?}",
                file.display(),
                case.index,
                case.script,
                case.text
            );

            if !case.preserve_check {
                continue;
            }
            preserve_checks += 1;
            let preserved =
                crate::transliterate(&result, &case.script, "Normal", Some(&preserve_options))
                    .unwrap_or_else(|e| {
                        panic!(
                            "transliterate (preserve check) error in `{}` index {}: {}",
                            file.display(),
                            case.index,
                            e
                        )
                    });
            assert_eq!(
                preserved,
                case.text,
                "Preserve check failed in `{}` index {} (script {}): input {:?}",
                file.display(),
                case.index,
                case.script,
                case.text
            );
        }
    }

    let summary = format!(
        "Typing Mode: total_emulations={}, preserve_checks={}, passed={}",
        total_emulations, preserve_checks, total_emulations
    );
    println!("{}", summary);

    // Logging is best effort: failures to create or write the log are ignored.
    let _ = std::fs::create_dir_all("test_log");
    let log_file = OpenOptions::new()
        .create(true)
        .write(true)
        .truncate(true)
        .open("test_log/typing_mode_log.txt");
    if let Ok(mut file) = log_file {
        let _ = writeln!(file, "{}", summary);
    }
}
/// A known script yields a non-empty common map whose rows all have non-empty text.
#[test]
fn test_get_script_typing_data_map_valid_script() {
    let lookup = get_script_typing_data_map("Devanagari");
    assert!(lookup.is_ok());
    let ScriptTypingDataMap {
        common_krama_map, ..
    } = lookup.unwrap();
    assert!(!common_krama_map.is_empty());
    common_krama_map
        .iter()
        .for_each(|(text, _, _)| assert!(!text.is_empty()));
}
/// A short script alias is accepted just like the full script name.
#[test]
fn test_get_script_typing_data_map_normalized_names() {
    let lookup = get_script_typing_data_map("dev");
    assert!(lookup.is_ok());
}
/// An unknown script name is rejected with the "Invalid script name" message.
#[test]
fn test_get_script_typing_data_map_invalid_script() {
    let failure = get_script_typing_data_map("InvalidScript").err();
    assert!(failure.is_some());
    assert_eq!(
        failure.as_deref(),
        Some("Invalid script name: InvalidScript")
    );
}
/// `Normal` is not a valid target for a typing data map.
#[test]
fn test_get_script_typing_data_map_normal_script() {
    let failure = get_script_typing_data_map("Normal").err();
    assert!(failure.is_some());
    assert_eq!(failure.as_deref(), Some("Invalid script name: Normal"));
}
/// At least one common row of a known script carries a typing mapping.
#[test]
fn test_get_script_typing_data_map_mappings_populated() {
    let lookup = get_script_typing_data_map("Telugu");
    assert!(lookup.is_ok());
    let data = lookup.unwrap();
    let first_mapped_row = data
        .common_krama_map
        .iter()
        .find(|(_, _, mappings)| !mappings.is_empty());
    assert!(
        first_mapped_row.is_some(),
        "Expected at least some characters to have typing mappings"
    );
}
}