use std::fmt;
use std::str::FromStr;
use serde::{Deserialize, Serialize};
use thiserror::Error;
/// Identifier of an embedding model: a model name paired with a numeric revision.
///
/// The textual form is `name@revision` (for example `bge-base-en-v1.5@1`); that is
/// what the `Display` impl writes, what the `FromStr` impl parses, and what the
/// serde impls read and write as a single string.
///
/// `EmbeddingModelId::new` and `FromStr` reject an empty name and revision `0`, and
/// `new` also rejects a name containing `@`. Both fields are public, so a value
/// built with a struct literal or mutated afterwards bypasses those checks.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct EmbeddingModelId {
/// Model name; the part before the `@` in the textual form.
pub name: String,
/// Model revision; the part after the `@` in the textual form. The validating
/// constructors only produce values greater than zero.
pub revision: u32,
}
impl EmbeddingModelId {
    /// Builds an identifier from its parts, validating both of them.
    ///
    /// # Errors
    ///
    /// The checks run in this order and the first failure is returned:
    ///
    /// * [`ParseEmbeddingModelIdError::EmptyName`] when `name` is empty.
    /// * [`ParseEmbeddingModelIdError::NameContainsAt`] when `name` contains `@`,
    ///   which is reserved as the separator of the textual form.
    /// * [`ParseEmbeddingModelIdError::InvalidRevision`] (carrying `"0"`) when
    ///   `revision` is zero.
    pub fn new(name: impl Into<String>, revision: u32) -> Result<Self, ParseEmbeddingModelIdError> {
        let name: String = name.into();

        // Classify the first violated rule, if any; the name checks take
        // precedence over the revision check.
        let rejection = if name.is_empty() {
            Some(ParseEmbeddingModelIdError::EmptyName)
        } else if name.contains('@') {
            Some(ParseEmbeddingModelIdError::NameContainsAt)
        } else if revision == 0 {
            Some(ParseEmbeddingModelIdError::InvalidRevision(String::from("0")))
        } else {
            None
        };

        match rejection {
            Some(err) => Err(err),
            None => Ok(Self { name, revision }),
        }
    }
}
impl fmt::Display for EmbeddingModelId {
    /// Writes the identifier in its textual form, `name@revision`.
    ///
    /// The pieces are written through a fresh set of format arguments, so any
    /// width, fill or alignment requested on `out` is not applied.
    fn fmt(&self, out: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self { name, revision } = self;
        out.write_fmt(format_args!("{name}@{revision}"))
    }
}
impl FromStr for EmbeddingModelId {
    type Err = ParseEmbeddingModelIdError;

    /// Parses the textual form `name@revision`.
    ///
    /// The input is split at the **last** `@`: everything after it is the
    /// revision and everything before it is the name. Splitting at the last
    /// separator is what lets a stray `@` in the name part be reported as
    /// [`ParseEmbeddingModelIdError::NameContainsAt`]; splitting at the first
    /// one would leave the name `@`-free by construction and misreport such
    /// input as an invalid revision.
    ///
    /// The revision is parsed with `u32`'s `FromStr`, so a leading `+` and
    /// leading zeros are accepted (`"m@+1"` and `"m@01"` both yield revision 1).
    ///
    /// # Errors
    ///
    /// * [`ParseEmbeddingModelIdError::MissingAtSeparator`] when `s` has no `@`.
    /// * [`ParseEmbeddingModelIdError::EmptyName`] when nothing precedes the `@`.
    /// * [`ParseEmbeddingModelIdError::NameContainsAt`] when the name part
    ///   itself contains `@` (i.e. `s` has more than one `@`).
    /// * [`ParseEmbeddingModelIdError::InvalidRevision`], carrying the raw
    ///   revision text, when it is not a `u32` or is zero.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let (name, rev) = s
            .rsplit_once('@')
            .ok_or(ParseEmbeddingModelIdError::MissingAtSeparator)?;

        if name.is_empty() {
            return Err(ParseEmbeddingModelIdError::EmptyName);
        }
        if name.contains('@') {
            return Err(ParseEmbeddingModelIdError::NameContainsAt);
        }

        // Unparsable text and an explicit zero are the same failure and both
        // report the revision exactly as the caller wrote it.
        let revision = match rev.parse::<u32>() {
            Ok(n) if n > 0 => n,
            _ => return Err(ParseEmbeddingModelIdError::InvalidRevision(rev.to_owned())),
        };

        Ok(Self {
            name: name.to_owned(),
            revision,
        })
    }
}
impl Serialize for EmbeddingModelId {
    /// Serializes the identifier as a single string, using the same text the
    /// `Display` impl produces.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        // `collect_str` hands the serializer a `Display` value to render.
        serializer.collect_str(self)
    }
}
impl<'de> Deserialize<'de> for EmbeddingModelId {
fn deserialize<D: serde::Deserializer<'de>>(de: D) -> Result<Self, D::Error> {
let s = String::deserialize(de)?;
s.parse().map_err(serde::de::Error::custom)
}
}
/// Reasons an `EmbeddingModelId` can be rejected, either when parsing the textual
/// form `name@revision` or when validating the parts passed to
/// `EmbeddingModelId::new`.
///
/// The `Display` text of each variant comes from its `#[error(...)]` attribute.
#[derive(Debug, Error, PartialEq, Eq)]
pub enum ParseEmbeddingModelIdError {
/// The parsed input contained no `@` at all.
#[error("missing `@` separator (expected `name@revision`)")]
MissingAtSeparator,
/// The model name was the empty string.
#[error("model name is empty")]
EmptyName,
/// The model name itself contained an `@`, which is reserved as the separator.
#[error("model name contains `@` (use the separator)")]
NameContainsAt,
/// The revision was not a positive integer that fits in `u32`. The payload is
/// the offending revision text (`"0"` when `EmbeddingModelId::new` is given 0).
#[error("invalid revision `{0}` (expected positive integer)")]
InvalidRevision(String),
}
/// Unit tests for parsing, formatting, validation and serde round-tripping of
/// `EmbeddingModelId`.
#[cfg(test)]
mod tests {
    use super::*;

    // --- parsing: accepted input -------------------------------------------

    #[test]
    fn parses_canonical_form() {
        let id: EmbeddingModelId = "bge-base-en-v1.5@1".parse().unwrap();
        assert_eq!(id.name, "bge-base-en-v1.5");
        assert_eq!(id.revision, 1);
    }

    #[test]
    fn round_trips() {
        let id = EmbeddingModelId::new("voyage-code-3", 42).unwrap();
        let s = id.to_string();
        let back: EmbeddingModelId = s.parse().unwrap();
        assert_eq!(id, back);
    }

    // --- parsing: rejected input -------------------------------------------

    #[test]
    fn rejects_missing_at() {
        assert_eq!(
            "bge-base-en-v1.5".parse::<EmbeddingModelId>(),
            Err(ParseEmbeddingModelIdError::MissingAtSeparator),
        );
    }

    #[test]
    fn rejects_empty_name() {
        assert_eq!(
            "@1".parse::<EmbeddingModelId>(),
            Err(ParseEmbeddingModelIdError::EmptyName),
        );
    }

    #[test]
    fn rejects_more_than_one_at() {
        // Only rejection is asserted here, not which variant reports it.
        assert!("a@b@1".parse::<EmbeddingModelId>().is_err());
    }

    #[test]
    fn rejects_zero_revision() {
        match "name@0".parse::<EmbeddingModelId>() {
            Err(ParseEmbeddingModelIdError::InvalidRevision(s)) => assert_eq!(s, "0"),
            other => panic!("expected InvalidRevision, got {other:?}"),
        }
    }

    #[test]
    fn rejects_empty_revision() {
        assert_eq!(
            "name@".parse::<EmbeddingModelId>(),
            Err(ParseEmbeddingModelIdError::InvalidRevision(String::new())),
        );
    }

    #[test]
    fn rejects_negative_revision() {
        assert!(matches!(
            "name@-1".parse::<EmbeddingModelId>(),
            Err(ParseEmbeddingModelIdError::InvalidRevision(_)),
        ));
    }

    #[test]
    fn rejects_non_numeric_revision() {
        assert!(matches!(
            "name@v1".parse::<EmbeddingModelId>(),
            Err(ParseEmbeddingModelIdError::InvalidRevision(_)),
        ));
    }

    #[test]
    fn rejects_overflow_revision() {
        let big = format!("name@{}", u64::from(u32::MAX) + 1);
        assert!(matches!(
            big.parse::<EmbeddingModelId>(),
            Err(ParseEmbeddingModelIdError::InvalidRevision(_)),
        ));
    }

    // --- constructor validation --------------------------------------------

    #[test]
    fn new_rejects_empty_name() {
        assert_eq!(
            EmbeddingModelId::new("", 1),
            Err(ParseEmbeddingModelIdError::EmptyName),
        );
    }

    #[test]
    fn new_rejects_at_in_name() {
        assert_eq!(
            EmbeddingModelId::new("a@b", 1),
            Err(ParseEmbeddingModelIdError::NameContainsAt),
        );
    }

    #[test]
    fn new_rejects_zero_revision() {
        assert_eq!(
            EmbeddingModelId::new("name", 0),
            Err(ParseEmbeddingModelIdError::InvalidRevision("0".to_owned())),
        );
    }

    // --- serde ---------------------------------------------------------------

    #[test]
    fn serde_uses_string_form() {
        let id = EmbeddingModelId::new("bge-base-en-v1.5", 1).unwrap();
        let j = serde_json::to_value(&id).unwrap();
        assert_eq!(j, serde_json::Value::String("bge-base-en-v1.5@1".into()));
        let back: EmbeddingModelId = serde_json::from_value(j).unwrap();
        assert_eq!(id, back);
    }

    #[test]
    fn serde_rejects_malformed_string() {
        let j = serde_json::Value::String("no-separator".into());
        assert!(serde_json::from_value::<EmbeddingModelId>(j).is_err());
    }
}