1use std::fmt;
14use std::str::FromStr;
15
16use serde::{Deserialize, Serialize};
17use thiserror::Error;
18
/// Identifier of an embedding model: a model name paired with a numeric revision.
///
/// The text form written by `Display` and read back by `FromStr` is `name@revision`,
/// for example `bge-base-en-v1.5@1`.
///
/// Both fields are public, so a value built with a struct literal skips the checks that
/// [`EmbeddingModelId::new`] and parsing perform.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub struct EmbeddingModelId {
    /// Model name. The validating constructors reject an empty name or one containing `@`.
    pub name: String,
    /// Revision number. The validating constructors reject `0`.
    pub revision: u32,
}
31
32impl EmbeddingModelId {
33 pub fn new(name: impl Into<String>, revision: u32) -> Result<Self, ParseEmbeddingModelIdError> {
41 let name = name.into();
42 if name.is_empty() {
43 return Err(ParseEmbeddingModelIdError::EmptyName);
44 }
45 if name.contains('@') {
46 return Err(ParseEmbeddingModelIdError::NameContainsAt);
47 }
48 if revision == 0 {
49 return Err(ParseEmbeddingModelIdError::InvalidRevision("0".to_owned()));
50 }
51 Ok(Self { name, revision })
52 }
53}
54
55impl fmt::Display for EmbeddingModelId {
56 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
57 write!(f, "{}@{}", self.name, self.revision)
58 }
59}
60
61impl FromStr for EmbeddingModelId {
62 type Err = ParseEmbeddingModelIdError;
63
64 fn from_str(s: &str) -> Result<Self, Self::Err> {
65 let (name, rev) = s
66 .split_once('@')
67 .ok_or(ParseEmbeddingModelIdError::MissingAtSeparator)?;
68 if name.is_empty() {
69 return Err(ParseEmbeddingModelIdError::EmptyName);
70 }
71 if name.contains('@') {
72 return Err(ParseEmbeddingModelIdError::NameContainsAt);
73 }
74 let revision: u32 = rev
75 .parse()
76 .map_err(|_| ParseEmbeddingModelIdError::InvalidRevision(rev.to_owned()))?;
77 if revision == 0 {
78 return Err(ParseEmbeddingModelIdError::InvalidRevision(rev.to_owned()));
79 }
80 Ok(Self {
81 name: name.to_owned(),
82 revision,
83 })
84 }
85}
86
87impl Serialize for EmbeddingModelId {
88 fn serialize<S: serde::Serializer>(&self, ser: S) -> Result<S::Ok, S::Error> {
89 ser.collect_str(self)
90 }
91}
92
93impl<'de> Deserialize<'de> for EmbeddingModelId {
94 fn deserialize<D: serde::Deserializer<'de>>(de: D) -> Result<Self, D::Error> {
95 let s = String::deserialize(de)?;
96 s.parse().map_err(serde::de::Error::custom)
97 }
98}
99
100#[derive(Debug, Error, PartialEq, Eq)]
102pub enum ParseEmbeddingModelIdError {
103 #[error("missing `@` separator (expected `name@revision`)")]
105 MissingAtSeparator,
106 #[error("model name is empty")]
108 EmptyName,
109 #[error("model name contains `@` (use the separator)")]
111 NameContainsAt,
112 #[error("invalid revision `{0}` (expected positive integer)")]
114 InvalidRevision(String),
115}
116
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand so each test does not repeat the turbofish.
    fn parse(text: &str) -> Result<EmbeddingModelId, ParseEmbeddingModelIdError> {
        text.parse()
    }

    #[test]
    fn parses_canonical_form() {
        let id = parse("bge-base-en-v1.5@1").unwrap();
        assert_eq!(id.name, "bge-base-en-v1.5");
        assert_eq!(id.revision, 1);
    }

    #[test]
    fn round_trips() {
        let id = EmbeddingModelId::new("voyage-code-3", 42).unwrap();
        let back = parse(&id.to_string()).unwrap();
        assert_eq!(id, back);
    }

    #[test]
    fn new_rejects_empty_name() {
        assert_eq!(
            EmbeddingModelId::new("", 1),
            Err(ParseEmbeddingModelIdError::EmptyName),
        );
    }

    #[test]
    fn new_rejects_at_in_name() {
        assert_eq!(
            EmbeddingModelId::new("a@b", 1),
            Err(ParseEmbeddingModelIdError::NameContainsAt),
        );
    }

    #[test]
    fn new_rejects_zero_revision() {
        assert_eq!(
            EmbeddingModelId::new("name", 0),
            Err(ParseEmbeddingModelIdError::InvalidRevision("0".to_owned())),
        );
    }

    #[test]
    fn rejects_missing_at() {
        assert_eq!(
            parse("bge-base-en-v1.5"),
            Err(ParseEmbeddingModelIdError::MissingAtSeparator),
        );
    }

    #[test]
    fn rejects_empty_name() {
        assert_eq!(parse("@1"), Err(ParseEmbeddingModelIdError::EmptyName));
    }

    #[test]
    fn rejects_extra_at() {
        // Only the failure is pinned here, not which variant reports it.
        assert!(parse("a@b@1").is_err());
    }

    #[test]
    fn rejects_zero_revision() {
        match parse("name@0") {
            Err(ParseEmbeddingModelIdError::InvalidRevision(s)) => assert_eq!(s, "0"),
            other => panic!("expected InvalidRevision, got {other:?}"),
        }
    }

    #[test]
    fn rejects_empty_revision() {
        assert_eq!(
            parse("name@"),
            Err(ParseEmbeddingModelIdError::InvalidRevision(String::new())),
        );
    }

    #[test]
    fn rejects_negative_revision() {
        assert!(matches!(
            parse("name@-1"),
            Err(ParseEmbeddingModelIdError::InvalidRevision(_)),
        ));
    }

    #[test]
    fn rejects_non_numeric_revision() {
        assert!(matches!(
            parse("name@v1"),
            Err(ParseEmbeddingModelIdError::InvalidRevision(_)),
        ));
    }

    #[test]
    fn rejects_overflow_revision() {
        let big = format!("name@{}", u64::from(u32::MAX) + 1);
        assert!(matches!(
            parse(&big),
            Err(ParseEmbeddingModelIdError::InvalidRevision(_)),
        ));
    }

    #[test]
    fn serde_uses_string_form() {
        let id = EmbeddingModelId::new("bge-base-en-v1.5", 1).unwrap();
        let j = serde_json::to_value(&id).unwrap();
        assert_eq!(j, serde_json::Value::String("bge-base-en-v1.5@1".into()));
        let back: EmbeddingModelId = serde_json::from_value(j).unwrap();
        assert_eq!(id, back);
    }

    #[test]
    fn serde_rejects_malformed_string() {
        let j = serde_json::Value::String("no-separator".into());
        assert!(serde_json::from_value::<EmbeddingModelId>(j).is_err());
    }
}