use crate::ensemble_llm;
use indexmap::IndexMap;
use serde::{Deserialize, Serialize};
use twox_hash::XxHash3_128;
/// Input form of an ensemble, before validation and normalisation.
///
/// Converted into an [`Ensemble`] through `TryFrom`, which skips zero-count
/// entries, merges entries sharing the same full ID, and derives the
/// ensemble's `id`.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct EnsembleBase {
/// LLM entries as supplied by the caller; the list may contain zero-count
/// entries and repeated entries.
pub llms: Vec<ensemble_llm::EnsembleLlmBaseWithFallbacksAndCount>,
}
/// Validated ensemble with a content-derived identifier.
///
/// When built via `TryFrom<EnsembleBase>`, `llms` is sorted by each entry's
/// full ID and contains one entry per distinct full ID.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Ensemble {
/// Identifier computed from the ensemble's contents: a base62-encoded
/// 128-bit hash, left-padded with `0` to at least 22 characters.
pub id: String,
/// The ensemble's LLM entries, each carrying its own `count`.
pub llms: Vec<ensemble_llm::EnsembleLlmWithFallbacksAndCount>,
}
impl TryFrom<EnsembleBase> for Ensemble {
    type Error = String;

    /// Validates and normalises an [`EnsembleBase`] into an [`Ensemble`].
    ///
    /// Processing steps:
    /// 1. Entries whose `count` is `0` are skipped.
    /// 2. Every remaining entry is converted with `try_into` and checked so
    ///    that no fallback shares the primary's ID and no two fallbacks share
    ///    an ID.
    /// 3. Entries with the same full ID are merged by summing their counts.
    /// 4. The merged entries are sorted by full ID so that the resulting `id`
    ///    does not depend on the order of the input list.
    /// 5. `id` is the base62 encoding of a seed-0 XXH3 128-bit hash over each
    ///    entry's full ID bytes followed by its little-endian count bytes,
    ///    left-padded with `0` to at least 22 characters.
    ///
    /// # Errors
    ///
    /// Returns an error message when:
    /// - converting an individual entry fails (that error is propagated),
    /// - a fallback ID equals the primary ID of the same entry,
    /// - two fallbacks of the same entry share an ID,
    /// - the total count over all entries is `0` or greater than `128`.
    fn try_from(
        EnsembleBase { llms: base_llms }: EnsembleBase,
    ) -> Result<Self, Self::Error> {
        let mut llms_with_full_id: IndexMap<
            String,
            ensemble_llm::EnsembleLlmWithFallbacksAndCount,
        > = IndexMap::with_capacity(base_llms.len());
        let mut count = 0;
        for base_llm in base_llms {
            if base_llm.count == 0 {
                continue;
            }
            // Saturate instead of wrapping: counts come from the caller, and
            // a wrapped total could slip back under the 128 limit checked
            // after the loop (or panic when overflow checks are enabled).
            count = base_llm.count.saturating_add(count);

            let llm: ensemble_llm::EnsembleLlmWithFallbacksAndCount =
                base_llm.try_into()?;

            if let Some(fallbacks) = &llm.fallbacks {
                if fallbacks.iter().any(|fb| fb.id == llm.inner.id) {
                    return Err(format!(
                        "Ensemble LLM cannot have identical primary and fallback IDs: {}",
                        llm.inner.id
                    ));
                }
                // Report the first fallback (in list order) that reappears
                // later in the list.
                for (i, fallback) in fallbacks.iter().enumerate() {
                    let repeated = fallbacks
                        .iter()
                        .skip(i + 1)
                        .any(|other| other.id == fallback.id);
                    if repeated {
                        return Err(format!(
                            "Ensemble LLM cannot have duplicate fallback IDs: {}",
                            fallback.id
                        ));
                    }
                }
            }

            // Merge entries that resolve to the same full ID. The sum is
            // saturated for the same reason as the total above; any ensemble
            // in which it saturates is rejected by the limit check below.
            let added = llm.count;
            llms_with_full_id
                .entry(llm.full_id())
                .and_modify(|existing| {
                    existing.count = existing.count.saturating_add(added);
                })
                .or_insert(llm);
        }

        if count == 0 || count > 128 {
            return Err(
                "`ensemble.llms` must contain between 1 and 128 total LLMs"
                    .to_string(),
            );
        }

        // Canonical ordering so equal ensembles hash to the same ID no matter
        // how the caller ordered the input.
        llms_with_full_id.sort_unstable_keys();

        // NOTE: the hash input (full ID bytes, then little-endian count bytes,
        // per entry) defines the ID format; changing it changes every ID.
        let mut hasher = XxHash3_128::with_seed(0);
        for (full_id, llm) in &llms_with_full_id {
            hasher.write(full_id.as_bytes());
            hasher.write(&llm.count.to_le_bytes());
        }
        let id = format!("{:0>22}", base62::encode(hasher.finish_128()));

        let llms = llms_with_full_id.into_values().collect::<Vec<_>>();
        Ok(Ensemble { id, llms })
    }
}