use std::path::Path;
use crate::config::Config;
use crate::embedding::EmbeddingEngine;
use crate::errors::Error;
use crate::sqlite::Database;
/// Maximum accepted size of a text input, measured in bytes (it is compared against `str::len`).
pub const MAX_INPUT_LENGTH: usize = 100_000;
/// Largest result limit accepted by `validate_limit`; larger requests are rejected.
pub const MAX_SEARCH_LIMIT: usize = 10_000;
pub(crate) fn validate_limit(limit: usize) -> Result<(), Error> {
if limit == 0 {
return Err(Error::InvalidInput(
"Limit must be greater than 0".to_string(),
));
}
if limit > MAX_SEARCH_LIMIT {
return Err(Error::InvalidInput(format!(
"Limit {} exceeds maximum allowed ({})",
limit, MAX_SEARCH_LIMIT
)));
}
Ok(())
}
/// A store that pairs a database handle with an embedding engine that is created on demand.
pub struct MemoryStore {
/// Database handle opened by `MemoryStore::new` from the resolved database path.
pub(crate) db: Database,
/// Embedding engine; starts as `None` and is filled in the first time `embedder()` is called.
pub(crate) embedder: Option<EmbeddingEngine>,
/// Model identifier handed to `EmbeddingEngine::new` when the engine is created.
pub(crate) model_id: String,
/// Configuration supplied by the caller at construction time and stored as-is.
pub(crate) config: Config,
}
impl MemoryStore {
pub fn new(db_path: &Path, model_id: &str, config: Config) -> Result<Self, Error> {
use std::path::Component;
for component in db_path.components() {
if matches!(component, Component::ParentDir) {
return Err(Error::Config(
"Invalid database path: contains '..' which may escape the intended directory"
.to_string(),
));
}
}
let db_real_path = if db_path.exists() {
std::fs::canonicalize(db_path).map_err(|e| {
Error::Config(format!(
"Invalid database path: cannot canonicalize existing path: {}",
e
))
})?
} else {
let parent = db_path.parent().ok_or_else(|| {
Error::Config("Invalid database path: no parent directory".to_string())
})?;
let canonical_parent = std::fs::canonicalize(parent).map_err(|e| {
Error::Config(format!(
"Invalid database path: parent directory not accessible: {}",
e
))
})?;
let filename = db_path
.file_name()
.ok_or_else(|| Error::Config("Invalid database path: no filename".to_string()))?;
canonical_parent.join(filename)
};
let db = Database::open(&db_real_path)?;
Ok(MemoryStore {
db,
embedder: None,
model_id: model_id.to_string(),
config,
})
}
pub(crate) fn embedder(&mut self) -> Result<&mut EmbeddingEngine, Error> {
if self.embedder.is_none() {
self.embedder = Some(EmbeddingEngine::new(&self.model_id)?);
}
Ok(self.embedder.as_mut().unwrap())
}
pub(crate) fn validate_input_length(text: &str) -> Result<(), Error> {
if text.trim().is_empty() {
return Err(Error::EmptyInput);
}
if text.len() > MAX_INPUT_LENGTH {
return Err(Error::InputTooLong {
max_length: MAX_INPUT_LENGTH,
actual_length: text.len(),
});
}
Ok(())
}
}