use crate::error::Error;
use serde::{Deserialize, Serialize};
/// Data-collection setting carried by [`ProviderPreferences::data_collection`].
///
/// Serde writes and reads each variant as its lowercase name
/// (`"allow"` / `"deny"`).
///
/// NOTE(review): presumably this controls whether providers that retain
/// request data may be used; confirm against the consuming API's documentation.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
#[serde(rename_all = "lowercase")]
pub enum DataCollection {
    /// Serialized as `"allow"`.
    Allow,
    /// Serialized as `"deny"`.
    Deny,
}
/// Sort criterion carried by [`ProviderPreferences::sort`].
///
/// Serde writes and reads each variant as its lowercase name
/// (`"price"` / `"throughput"`).
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
#[serde(rename_all = "lowercase")]
pub enum ProviderSort {
    /// Serialized as `"price"`.
    Price,
    /// Serialized as `"throughput"`.
    Throughput,
}
/// Quantization label used in [`ProviderPreferences::quantizations`].
///
/// Serde writes and reads each variant as its lowercase name, e.g.
/// `Int4` <-> `"int4"` and `Bf16` <-> `"bf16"`.
///
/// NOTE(review): the variant names suggest numeric precision formats
/// (integer, floating-point, bfloat); confirm the exact meaning and the
/// accepted set against the consuming API's documentation.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
#[serde(rename_all = "lowercase")]
pub enum Quantization {
    /// Serialized as `"int4"`.
    Int4,
    /// Serialized as `"int8"`.
    Int8,
    /// Serialized as `"fp6"`.
    Fp6,
    /// Serialized as `"fp8"`.
    Fp8,
    /// Serialized as `"fp16"`.
    Fp16,
    /// Serialized as `"bf16"`.
    Bf16,
    /// Serialized as `"fp32"`.
    Fp32,
    /// Serialized as `"unknown"`.
    Unknown,
}
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "camelCase")]
pub struct ProviderPreferences {
#[serde(skip_serializing_if = "Option::is_none")]
pub order: Option<Vec<String>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub allow_fallbacks: Option<bool>,
#[serde(skip_serializing_if = "Option::is_none")]
pub require_parameters: Option<bool>,
#[serde(skip_serializing_if = "Option::is_none")]
pub data_collection: Option<DataCollection>,
#[serde(skip_serializing_if = "Option::is_none")]
pub ignore: Option<Vec<String>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub quantizations: Option<Vec<Quantization>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub sort: Option<ProviderSort>,
}
impl Default for ProviderPreferences {
fn default() -> Self {
Self::new()
}
}
impl ProviderPreferences {
pub fn new() -> Self {
Self {
order: None,
allow_fallbacks: None,
require_parameters: None,
data_collection: None,
ignore: None,
quantizations: None,
sort: None,
}
}
pub fn validate(&self) -> Result<(), Error> {
if let Some(ref order) = self.order {
if order.is_empty() {
return Err(Error::ConfigError(
"Provider order list cannot be empty".to_string(),
));
}
let mut seen = std::collections::HashSet::new();
for provider in order {
if !seen.insert(provider) {
return Err(Error::ConfigError(format!(
"Duplicate provider in order list: {provider}"
)));
}
}
}
Ok(())
}
pub fn with_order(mut self, order: Vec<String>) -> Self {
self.order = Some(order);
self
}
pub fn with_allow_fallbacks(mut self, allow_fallbacks: bool) -> Self {
self.allow_fallbacks = Some(allow_fallbacks);
self
}
pub fn with_require_parameters(mut self, require_parameters: bool) -> Self {
self.require_parameters = Some(require_parameters);
self
}
pub fn with_data_collection(mut self, data_collection: DataCollection) -> Self {
self.data_collection = Some(data_collection);
self
}
pub fn with_ignore(mut self, ignore: Vec<String>) -> Self {
self.ignore = Some(ignore);
self
}
pub fn with_quantizations(mut self, quantizations: Vec<Quantization>) -> Self {
self.quantizations = Some(quantizations);
self
}
pub fn with_sort(mut self, sort: ProviderSort) -> Self {
self.sort = Some(sort);
self
}
}