1use std::collections::HashMap;
4
/// A table mapping short alias names (e.g. `"fast"`) to concrete model
/// identifiers.
///
/// Names that are not registered as aliases resolve to themselves, so the
/// table can be applied to any user-supplied model name.
///
/// `PartialEq`/`Eq` compare tables by their full alias → model contents.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct ModelAliases {
    // alias name -> model identifier
    map: HashMap<String, String>,
}
10
11impl ModelAliases {
12 pub fn new() -> Self {
14 Self::default()
15 }
16
17 pub fn with_defaults() -> Self {
20 let mut m = Self::new();
21 m.set("fast", "gpt-4o-mini");
22 m.set("smart", "claude-sonnet-4-6");
23 m.set("cheap", "gpt-4o-mini");
24 m.set("local", "llama3.1");
25 m
26 }
27
28 pub fn set(&mut self, alias: impl Into<String>, model: impl Into<String>) {
30 self.map.insert(alias.into(), model.into());
31 }
32
33 pub fn resolve<'a>(&'a self, name: &'a str) -> &'a str {
35 self.map.get(name).map(String::as_str).unwrap_or(name)
36 }
37}
38
#[cfg(test)]
mod tests {
    use super::*;

    /// Every built-in alias resolves to its model; unknown names pass through.
    #[test]
    fn resolves_defaults_and_passthrough() {
        let a = ModelAliases::with_defaults();
        assert_eq!(a.resolve("fast"), "gpt-4o-mini");
        assert_eq!(a.resolve("smart"), "claude-sonnet-4-6");
        assert_eq!(a.resolve("cheap"), "gpt-4o-mini");
        assert_eq!(a.resolve("local"), "llama3.1");
        assert_eq!(a.resolve("gpt-4o"), "gpt-4o");
    }

    /// A table built with `new` has no aliases, so every name resolves to itself.
    #[test]
    fn empty_table_passes_everything_through() {
        let a = ModelAliases::new();
        assert_eq!(a.resolve("fast"), "fast");
        assert_eq!(a.resolve(""), "");
    }

    /// `set` replaces an existing alias without disturbing the others.
    #[test]
    fn set_overrides_existing_alias() {
        let mut a = ModelAliases::with_defaults();
        a.set("fast", "custom-model");
        assert_eq!(a.resolve("fast"), "custom-model");
        assert_eq!(a.resolve("cheap"), "gpt-4o-mini");
    }

    /// Resolution performs one lookup; alias-to-alias chains are not followed.
    #[test]
    fn resolution_is_single_level() {
        let mut a = ModelAliases::new();
        a.set("a", "b");
        a.set("b", "c");
        assert_eq!(a.resolve("a"), "b");
        assert_eq!(a.resolve("b"), "c");
    }
}