/// The pair of column names holding code and text embeddings for one embedding dimension.
///
/// Every field is a `&'static str`, so the type is `Copy` (no need to `.clone()` it) and
/// `Hash` (it can be used as a map or set key).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct ColumnSet {
    /// Name of the column storing code embeddings.
    pub code_embedding: &'static str,
    /// Name of the column storing text embeddings.
    pub text_embedding: &'static str,
}

impl ColumnSet {
    /// Columns suffixed with `_ollama` (selected for 768-dimensional embeddings).
    pub const OLLAMA: Self = Self {
        code_embedding: "code_embedding_ollama",
        text_embedding: "text_embedding_ollama",
    };

    /// Columns suffixed with `_mxbai` (selected for 1024-dimensional embeddings).
    pub const MXBAI: Self = Self {
        code_embedding: "code_embedding_mxbai",
        text_embedding: "text_embedding_mxbai",
    };

    /// Unsuffixed columns (selected for 1536-dimensional embeddings).
    pub const OPENAI: Self = Self {
        code_embedding: "code_embedding",
        text_embedding: "text_embedding",
    };
}
/// Picks the [`ColumnSet`] that matches an embedding vector's dimension.
///
/// | dimension | columns                |
/// |-----------|------------------------|
/// | 768       | [`ColumnSet::OLLAMA`]  |
/// | 1024      | [`ColumnSet::MXBAI`]   |
/// | 1536      | [`ColumnSet::OPENAI`]  |
///
/// # Errors
///
/// Returns an error that names the rejected value and lists the supported
/// dimensions when `dimension` is anything other than 768, 1024, or 1536.
pub fn select_columns_for_dimension(dimension: usize) -> anyhow::Result<ColumnSet> {
    let columns = match dimension {
        768 => ColumnSet::OLLAMA,
        1024 => ColumnSet::MXBAI,
        1536 => ColumnSet::OPENAI,
        unsupported => {
            return Err(anyhow::anyhow!(
                "Unsupported embedding dimension: {}. Supported dimensions: 768 (Ollama/Google), 1024 (mxbai-embed-large), 1536 (OpenAI)",
                unsupported
            ));
        }
    };
    Ok(columns)
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_768_dimension_selects_ollama_columns() {
        let cols = select_columns_for_dimension(768).unwrap();
        assert_eq!(cols, ColumnSet::OLLAMA);
        assert_eq!(cols.code_embedding, "code_embedding_ollama");
        assert_eq!(cols.text_embedding, "text_embedding_ollama");
    }

    #[test]
    fn test_1536_dimension_selects_openai_columns() {
        let cols = select_columns_for_dimension(1536).unwrap();
        assert_eq!(cols, ColumnSet::OPENAI);
        assert_eq!(cols.code_embedding, "code_embedding");
        assert_eq!(cols.text_embedding, "text_embedding");
    }

    #[test]
    fn test_1024_dimension_selects_mxbai_columns() {
        let cols = select_columns_for_dimension(1024).unwrap();
        assert_eq!(cols, ColumnSet::MXBAI);
        assert_eq!(cols.code_embedding, "code_embedding_mxbai");
        assert_eq!(cols.text_embedding, "text_embedding_mxbai");
    }

    #[test]
    fn test_unsupported_dimension_returns_error() {
        let result = select_columns_for_dimension(384);
        assert!(result.is_err());
        let err_msg = result.unwrap_err().to_string();
        assert!(err_msg.contains("Unsupported"));
        assert!(err_msg.contains("384"));
        assert!(err_msg.contains("768"));
        assert!(err_msg.contains("1024"));
        assert!(err_msg.contains("1536"));
    }

    #[test]
    fn test_zero_dimension_returns_error() {
        let result = select_columns_for_dimension(0);
        assert!(result.is_err());
        let err_msg = result.unwrap_err().to_string();
        assert!(err_msg.contains("Unsupported"));
        // A bare `contains("0")` would always pass because the supported list includes "1024";
        // pin the interpolated value itself so the test fails if it is ever dropped.
        assert!(err_msg.contains("Unsupported embedding dimension: 0."));
    }

    #[test]
    fn test_adjacent_dimensions_return_error() {
        // Off-by-one neighbours of every supported dimension must be rejected.
        for dim in [767, 769, 1023, 1025, 1535, 1537] {
            assert!(
                select_columns_for_dimension(dim).is_err(),
                "dimension {} should be rejected",
                dim
            );
        }
    }

    #[test]
    fn test_large_dimension_returns_error() {
        let result = select_columns_for_dimension(3072);
        assert!(result.is_err());
        let msg = result.unwrap_err().to_string();
        assert!(msg.contains("3072"));
        assert!(msg.contains("768") && msg.contains("1024") && msg.contains("1536"));
    }

    #[test]
    fn test_column_set_constants() {
        assert_eq!(ColumnSet::OLLAMA.code_embedding, "code_embedding_ollama");
        assert_eq!(ColumnSet::OLLAMA.text_embedding, "text_embedding_ollama");
        assert_eq!(ColumnSet::MXBAI.code_embedding, "code_embedding_mxbai");
        assert_eq!(ColumnSet::MXBAI.text_embedding, "text_embedding_mxbai");
        assert_eq!(ColumnSet::OPENAI.code_embedding, "code_embedding");
        assert_eq!(ColumnSet::OPENAI.text_embedding, "text_embedding");
    }

    #[test]
    fn test_column_set_equality() {
        let cols1 = select_columns_for_dimension(768).unwrap();
        let cols2 = select_columns_for_dimension(768).unwrap();
        assert_eq!(cols1, cols2);
        let cols3 = select_columns_for_dimension(1536).unwrap();
        assert_ne!(cols1, cols3);
    }

    #[test]
    fn test_error_message_helpful() {
        let result = select_columns_for_dimension(512);
        assert!(result.is_err());
        let err_msg = result.unwrap_err().to_string();
        assert!(err_msg.contains("512"));
        assert!(err_msg.contains("768"));
        assert!(err_msg.contains("1024"));
        assert!(err_msg.contains("1536"));
        // The message should name at least one provider so users know which model fits.
        assert!(
            err_msg.contains("Ollama")
                || err_msg.contains("Google")
                || err_msg.contains("OpenAI")
                || err_msg.contains("mxbai")
        );
    }
}