use std::collections::HashMap;
use std::path::PathBuf;
use std::sync::atomic::AtomicUsize;
use std::sync::Arc;
use cqs::embedder::ModelConfig;
use cqs::parser::{CallSite, ChunkTypeRefs, FunctionCalls};
use cqs::{Chunk, Embedding, Store};
/// Bundle of relationship collections (type references, function calls and
/// chunk-level call sites).
///
/// Carried as the `relationships` field of [`ParsedBatch`],
/// [`PreparedEmbedding`] and [`EmbeddedBatch`]. `Default` yields empty
/// collections.
#[derive(Clone, Default)]
pub(super) struct RelationshipData {
/// `ChunkTypeRefs` lists keyed by path.
/// NOTE(review): presumably the source file they were extracted from — confirm.
pub type_refs: HashMap<PathBuf, Vec<ChunkTypeRefs>>,
/// `FunctionCalls` lists keyed by path.
/// NOTE(review): presumably the source file they were extracted from — confirm.
pub function_calls: HashMap<PathBuf, Vec<FunctionCalls>>,
/// `(String, CallSite)` pairs.
/// NOTE(review): the `String` looks like a chunk identifier — confirm against the producer.
pub chunk_calls: Vec<(String, CallSite)>,
}
/// A batch of chunks together with their relationship data and per-file
/// mtimes.
///
/// NOTE(review): by its name this is the output of the parse stage; the
/// producer is not visible in this section — confirm.
pub(super) struct ParsedBatch {
/// Chunks contained in this batch.
pub chunks: Vec<Chunk>,
/// Relationship data accompanying `chunks`.
pub relationships: RelationshipData,
/// Per-path `i64` values.
/// NOTE(review): presumably file modification times; unit/epoch not visible here — confirm.
pub file_mtimes: HashMap<PathBuf, i64>,
}
/// A batch of chunks paired with their embeddings, plus the relationship data
/// and per-file mtimes that travel with them.
///
/// NOTE(review): by its name this is the output of the embedding stage; the
/// producer and consumer are not visible in this section — confirm.
pub(super) struct EmbeddedBatch {
/// Each chunk together with its embedding.
pub chunk_embeddings: Vec<(Chunk, Embedding)>,
/// Relationship data accompanying the chunks.
pub relationships: RelationshipData,
/// NOTE(review): presumably how many of `chunk_embeddings` were served from
/// a cache rather than freshly computed — confirm.
pub cached_count: usize,
/// Per-path `i64` values.
/// NOTE(review): presumably file modification times; unit/epoch not visible here — confirm.
pub file_mtimes: HashMap<PathBuf, i64>,
}
/// Counters summarising a pipeline run.
///
/// NOTE(review): the field meanings below are inferred from their names; the
/// code that updates them is not visible in this section — confirm.
pub(crate) struct PipelineStats {
/// Presumably the number of chunks that were embedded.
pub total_embedded: usize,
/// Presumably the number of embeddings obtained from a cache.
pub total_cached: usize,
/// Presumably the number of GPU embedding failures encountered.
pub gpu_failures: usize,
/// Presumably the number of parse errors encountered.
pub parse_errors: usize,
/// Presumably the number of type-reference edges recorded.
pub total_type_edges: usize,
/// Presumably the number of call relationships recorded.
pub total_calls: usize,
}
/// Chunks split into those that already have an embedding and those that
/// still need one, plus the relationship data and per-file mtimes that travel
/// with them.
pub(super) struct PreparedEmbedding {
/// Chunks that already come with an embedding.
/// NOTE(review): presumably cache hits — confirm.
pub cached: Vec<(Chunk, Embedding)>,
/// Chunks without an embedding yet.
pub to_embed: Vec<Chunk>,
/// NOTE(review): presumably the input text for each entry of `to_embed`,
/// in the same order — confirm against the code that builds this struct.
pub texts: Vec<String>,
/// Relationship data accompanying the chunks.
pub relationships: RelationshipData,
/// Per-path `i64` values.
/// NOTE(review): presumably file modification times; unit/epoch not visible here — confirm.
pub file_mtimes: HashMap<PathBuf, i64>,
}
/// Shared handles and configuration bundled together for an embedding stage.
///
/// NOTE(review): how each field is used is not visible in this section; the
/// notes below are inferred from names and types — confirm.
pub(super) struct EmbedStageContext {
/// Reference-counted handle to the `Store`.
pub store: Arc<Store>,
/// Shared atomic counter.
/// NOTE(review): presumably incremented as chunks are embedded — confirm.
pub embedded_count: Arc<AtomicUsize>,
/// Embedding model configuration.
pub model_config: ModelConfig,
/// Optional shared embedding cache; `None` when no cache is supplied.
pub global_cache: Option<Arc<cqs::cache::EmbeddingCache>>,
}
/// Returns the file batch size, optionally overridden by the
/// `CQS_FILE_BATCH_SIZE` environment variable.
///
/// The value is resolved on the first call and cached for the lifetime of the
/// process; later changes to the environment are not observed.
///
/// * Variable holds a positive integer: that value is used and an `info`
///   event is logged.
/// * Variable is set but is zero or not parseable as `usize`: a `warn` event
///   is logged and the default of 5000 is used.
/// * Variable cannot be read (`std::env::var` returns `Err`): the default of
///   5000 is used without logging.
pub(super) fn file_batch_size() -> usize {
    const DEFAULT_FILE_BATCH_SIZE: usize = 5_000;
    static CACHED: std::sync::OnceLock<usize> = std::sync::OnceLock::new();

    *CACHED.get_or_init(|| {
        // Unreadable variable: fall back silently.
        let Ok(raw) = std::env::var("CQS_FILE_BATCH_SIZE") else {
            return DEFAULT_FILE_BATCH_SIZE;
        };

        // Only strictly positive integers count as a valid override.
        match raw.parse::<usize>().ok().filter(|&count| count > 0) {
            Some(count) => {
                tracing::info!(batch_size = count, "CQS_FILE_BATCH_SIZE override");
                count
            }
            None => {
                tracing::warn!(value = %raw, "Invalid CQS_FILE_BATCH_SIZE, using default 5000");
                DEFAULT_FILE_BATCH_SIZE
            }
        }
    })
}
/// Returns the parse channel depth, optionally overridden by the
/// `CQS_PARSE_CHANNEL_DEPTH` environment variable.
///
/// The value is resolved on the first call and cached for the lifetime of the
/// process. A positive integer in the variable is used as-is (with an `info`
/// event); a zero or unparseable value logs a `warn` event and falls back to
/// 512; if the variable cannot be read, 512 is used without logging.
pub(super) fn parse_channel_depth() -> usize {
    const FALLBACK_DEPTH: usize = 512;
    static CACHED: std::sync::OnceLock<usize> = std::sync::OnceLock::new();

    *CACHED.get_or_init(|| {
        let raw = match std::env::var("CQS_PARSE_CHANNEL_DEPTH") {
            Ok(text) => text,
            // Unreadable variable: fall back silently.
            Err(_) => return FALLBACK_DEPTH,
        };

        // Accept only strictly positive integers as an override.
        if let Some(requested) = raw.parse::<usize>().ok().filter(|&d| d > 0) {
            tracing::info!(depth = requested, "CQS_PARSE_CHANNEL_DEPTH override");
            return requested;
        }

        tracing::warn!(value = %raw, "Invalid CQS_PARSE_CHANNEL_DEPTH, using default 512");
        FALLBACK_DEPTH
    })
}
/// Returns the embed channel depth, optionally overridden by the
/// `CQS_EMBED_CHANNEL_DEPTH` environment variable.
///
/// The value is resolved on the first call and cached for the lifetime of the
/// process. A positive integer in the variable is used as-is (with an `info`
/// event); a zero or unparseable value logs a `warn` event and falls back to
/// 64; if the variable cannot be read, 64 is used without logging.
pub(super) fn embed_channel_depth() -> usize {
    const FALLBACK_DEPTH: usize = 64;
    static CACHED: std::sync::OnceLock<usize> = std::sync::OnceLock::new();

    // Resolves the depth from the environment; run at most once via `CACHED`.
    fn resolve() -> usize {
        let Ok(raw) = std::env::var("CQS_EMBED_CHANNEL_DEPTH") else {
            // Unreadable variable: fall back silently.
            return FALLBACK_DEPTH;
        };

        let override_depth = raw.parse::<usize>().ok().filter(|&d| d > 0);
        if let Some(requested) = override_depth {
            tracing::info!(depth = requested, "CQS_EMBED_CHANNEL_DEPTH override");
            requested
        } else {
            tracing::warn!(value = %raw, "Invalid CQS_EMBED_CHANNEL_DEPTH, using default 64");
            FALLBACK_DEPTH
        }
    }

    *CACHED.get_or_init(resolve)
}
/// Returns the embedding batch size, optionally overridden by the
/// `CQS_EMBED_BATCH_SIZE` environment variable.
///
/// The result is not cached: the environment is consulted on every call, and
/// the corresponding log event is emitted each time.
///
/// * Variable holds a positive integer: that value is returned and an `info`
///   event is logged.
/// * Variable is set but is zero or not parseable as `usize`: a `warn` event
///   is logged and 64 is returned.
/// * Variable cannot be read (`std::env::var` returns `Err`): 64 is returned
///   without logging.
pub(crate) fn embed_batch_size() -> usize {
    const DEFAULT_EMBED_BATCH_SIZE: usize = 64;

    // Unreadable variable: fall back silently.
    let Ok(raw) = std::env::var("CQS_EMBED_BATCH_SIZE") else {
        return DEFAULT_EMBED_BATCH_SIZE;
    };

    // Only strictly positive integers count as a valid override.
    if let Some(requested) = raw.parse::<usize>().ok().filter(|&n| n > 0) {
        tracing::info!(batch_size = requested, "CQS_EMBED_BATCH_SIZE override");
        return requested;
    }

    tracing::warn!(
        value = %raw,
        "Invalid CQS_EMBED_BATCH_SIZE, using default 64"
    );
    DEFAULT_EMBED_BATCH_SIZE
}