yomo 0.6.4

A QUIC-based runtime for AI-LLM tool routing and serverless execution
Documentation
use std::collections::HashMap;
use std::sync::Arc;

use crate::model_api_provider::providers;
use crate::model_api_provider::selection::{SelectionError, SelectionResult, SelectionStrategy};
use crate::model_api_provider::ModelApiProvider;
use crate::serve_config::{ConfigError, ModelApiConfig, ModelApiEndpointConfig, ProviderConfig};

/// A configured model together with the provider client that serves it.
///
/// Cloning is cheap with respect to the provider itself: the client is held
/// behind an [`Arc`], so clones share the same instance.
#[derive(Clone)]
pub struct ProviderEntry {
    /// Model identifier, copied from the provider configuration.
    pub model_id: String,
    /// Optional label, copied from the provider configuration.
    pub label: Option<String>,
    /// Shared handle to the provider implementation built for this model.
    pub provider: Arc<dyn ModelApiProvider>,
}

/// Registry of provider entries grouped by endpoint, combined with a
/// selection strategy that decides which model a request resolves to.
///
/// `M` is the metadata type handed through to the [`SelectionStrategy`].
pub struct ProviderRegistry<M> {
    /// Outer key: endpoint path. Inner key: model id listed for that endpoint.
    providers: HashMap<String, HashMap<String, ProviderEntry>>,
    /// Endpoint configuration keyed by endpoint path.
    endpoints: HashMap<String, ModelApiEndpointConfig>,
    /// Strategy used by `select` to choose a model id for a request.
    strategy: Arc<dyn SelectionStrategy<M>>,
}

impl<M> ProviderRegistry<M> {
    /// Builds a registry from the model API configuration.
    ///
    /// A provider entry is built for every model referenced by every
    /// configured endpoint, and each endpoint configuration is indexed by its
    /// path so it can be retrieved later through [`Self::endpoint`].
    ///
    /// # Errors
    ///
    /// Propagates any [`ConfigError`] raised while building the providers
    /// (for example an endpoint that references a model with no provider
    /// configuration, or an endpoint path that is not recognised).
    pub fn from_config(
        config: &ModelApiConfig,
        strategy: Arc<dyn SelectionStrategy<M>>,
    ) -> Result<Self, ConfigError> {
        let providers = build_providers(&config.providers, &config.endpoints)?;
        let endpoints = config
            .endpoints
            .iter()
            .map(|endpoint| (endpoint.path.clone(), endpoint.clone()))
            .collect();
        Ok(Self {
            providers,
            endpoints,
            strategy,
        })
    }

    /// Resolves the provider entry that should serve a request.
    ///
    /// The selection strategy first picks a model id from `endpoint`, the
    /// optional requested `model_id`, and `metadata`. That id is then looked
    /// up among the providers registered for `endpoint`.
    ///
    /// An exact match on the model id is preferred. If there is none, the
    /// lookup falls back to an ASCII case-insensitive comparison. Trying the
    /// exact match first keeps the result deterministic when two configured
    /// model ids differ only by case; a plain scan over the map would return
    /// whichever one hash-map iteration happened to yield first.
    ///
    /// # Errors
    ///
    /// Returns whatever error the strategy reports, or
    /// [`SelectionError::ModelNotSupported`] when the endpoint is unknown or
    /// has no provider for the selected model.
    pub fn select(
        &self,
        endpoint: &str,
        model_id: Option<&str>,
        metadata: &M,
    ) -> Result<ProviderEntry, SelectionError> {
        let selected = self.strategy.select(endpoint, model_id, metadata)?;
        let endpoint_models = self
            .providers
            .get(endpoint)
            .ok_or(SelectionError::ModelNotSupported)?;

        endpoint_models
            .get(selected.model_id.as_str())
            .or_else(|| {
                // `eq_ignore_ascii_case` compares in place, avoiding the two
                // lowercased `String` allocations per candidate.
                endpoint_models
                    .values()
                    .find(|entry| entry.model_id.eq_ignore_ascii_case(&selected.model_id))
            })
            .cloned()
            .ok_or(SelectionError::ModelNotSupported)
    }

    /// Returns the configuration of the endpoint registered under `path`,
    /// or `None` when no such endpoint was configured.
    pub fn endpoint(&self, path: &str) -> Option<&ModelApiEndpointConfig> {
        self.endpoints.get(path)
    }
}

/// Selection strategy that uses the model requested by the caller and
/// otherwise falls back to the endpoint's configured default model.
pub struct ByEndpointModel {
    /// Endpoint configurations, looked up by the endpoint string given to `select`.
    endpoints: HashMap<String, ModelApiEndpointConfig>,
}

impl ByEndpointModel {
    pub fn new(endpoints: HashMap<String, ModelApiEndpointConfig>) -> Self {
        Self { endpoints }
    }
}

impl<M> SelectionStrategy<M> for ByEndpointModel {
    /// Picks the model id for a request.
    ///
    /// Resolution order:
    /// 1. the caller-supplied `model_id`, if it is not blank (returned
    ///    verbatim, without trimming);
    /// 2. the `default_model` of the configuration registered for
    ///    `endpoint`, if that is set and not blank.
    ///
    /// The metadata argument is not consulted by this strategy.
    ///
    /// # Errors
    ///
    /// Returns [`SelectionError::ModelNotSupported`] when neither source
    /// yields a non-blank model id.
    fn select(
        &self,
        endpoint: &str,
        model_id: Option<&str>,
        _metadata: &M,
    ) -> Result<SelectionResult, SelectionError> {
        // An explicit, non-blank request always wins.
        let requested = model_id
            .filter(|id| !id.trim().is_empty())
            .map(str::to_string);

        // Otherwise consult the endpoint's default; the closure runs only when
        // nothing usable was requested.
        let resolved = requested.or_else(|| {
            self.endpoints
                .get(endpoint)
                .and_then(|config| config.default_model.as_ref())
                .filter(|fallback| !fallback.trim().is_empty())
                .cloned()
        });

        resolved
            .map(|model_id| SelectionResult { model_id })
            .ok_or(SelectionError::ModelNotSupported)
    }
}

/// Builds the per-endpoint provider tables.
///
/// For every endpoint, a [`ProviderEntry`] is created for each model the
/// endpoint lists, plus its `default_model` when that is set and not already
/// listed. The result maps endpoint path -> model id -> entry.
///
/// # Errors
///
/// Returns [`ConfigError::InvalidProvider`] when an endpoint references a
/// model id that has no provider configuration, and propagates any error
/// from `build_provider`.
fn build_providers(
    providers: &[ProviderConfig],
    endpoints: &[ModelApiEndpointConfig],
) -> Result<HashMap<String, HashMap<String, ProviderEntry>>, ConfigError> {
    // Index provider configs by model id; if an id repeats, the later config wins.
    let configs_by_model: HashMap<String, &ProviderConfig> = providers
        .iter()
        .map(|config| (config.model_id.clone(), config))
        .collect();

    let mut registry: HashMap<String, HashMap<String, ProviderEntry>> = HashMap::new();
    for endpoint in endpoints {
        // The default model is visited after the listed models, and only when
        // it is not already among them.
        let unlisted_default = endpoint
            .default_model
            .as_ref()
            .filter(|candidate| !endpoint.models.iter().any(|listed| listed == *candidate));

        let mut entries: HashMap<String, ProviderEntry> = HashMap::new();
        for model_id in endpoint.models.iter().chain(unlisted_default) {
            let config: &ProviderConfig = match configs_by_model.get(model_id) {
                Some(found) => found,
                None => {
                    return Err(ConfigError::InvalidProvider(format!(
                        "model_api endpoint model not found: {}",
                        model_id
                    )))
                }
            };
            let client = build_provider(config, endpoint.path.as_str())?;
            entries.insert(
                model_id.clone(),
                ProviderEntry {
                    model_id: config.model_id.clone(),
                    label: config.label.clone(),
                    provider: client,
                },
            );
        }
        registry.insert(endpoint.path.clone(), entries);
    }

    Ok(registry)
}

/// Builds the provider client appropriate for `endpoint_path`.
///
/// `/messages`, `/responses` and `/models/:generateContent` each have a
/// dedicated client builder; the embeddings, rerank, audio and image
/// endpoints all share the passthrough builder.
///
/// # Errors
///
/// Returns [`ConfigError::InvalidProvider`] for an endpoint path that is not
/// one of the recognised paths, and otherwise whatever the chosen builder
/// returns.
fn build_provider(
    provider: &ProviderConfig,
    endpoint_path: &str,
) -> Result<Arc<dyn ModelApiProvider>, ConfigError> {
    match endpoint_path {
        "/messages" => providers::messages::build_client(provider),
        "/responses" => providers::responses::build_client(provider),
        "/models/:generateContent" => providers::generate_content::build_client(provider),
        "/embeddings"
        | "/rerank"
        | "/audio/speech"
        | "/audio/transcriptions"
        | "/images/generations"
        | "/images/edits" => providers::passthrough::build_client(provider),
        unknown => Err(ConfigError::InvalidProvider(format!(
            "unknown model_api endpoint: {}",
            unknown
        ))),
    }
}