unigateway 1.2.1

Lightweight, local-first LLM gateway for developers. A stable, single-binary unified entry point for all your AI tools and models.
use super::{GatewayApiKey, GatewayState, ProviderEntry, ServiceProvider, default_round_robin};

fn to_service_provider(p: ProviderEntry) -> ServiceProvider {
    ServiceProvider {
        name: p.name,
        provider_type: p.provider_type,
        endpoint_id: if p.endpoint_id.is_empty() {
            None
        } else {
            Some(p.endpoint_id)
        },
        base_url: if p.base_url.is_empty() {
            None
        } else {
            Some(p.base_url)
        },
        api_key: Some(p.api_key),
        default_model: if p.default_model.is_empty() {
            None
        } else {
            Some(p.default_model)
        },
        model_mapping: if p.model_mapping.is_empty() {
            None
        } else {
            Some(p.model_mapping)
        },
    }
}

impl GatewayState {
    pub async fn find_gateway_api_key(&self, raw_key: &str) -> Option<GatewayApiKey> {
        let guard = self.inner.read().await;
        let k = guard.file.api_keys.iter().find(|a| a.key == raw_key)?;
        Some(GatewayApiKey {
            key: k.key.clone(),
            service_id: k.service_id.clone(),
            quota_limit: k.quota_limit,
            used_quota: k.used_quota,
            is_active: if k.is_active { 1 } else { 0 },
            qps_limit: k.qps_limit,
            concurrency_limit: k.concurrency_limit,
        })
    }

    pub async fn get_routing_strategy(&self, service_id: &str) -> String {
        let guard = self.inner.read().await;
        guard
            .file
            .services
            .iter()
            .find(|s| s.id == service_id)
            .map(|s| s.routing_strategy.clone())
            .unwrap_or_else(default_round_robin)
    }

    pub async fn select_all_providers_for_service(
        &self,
        service_id: &str,
        protocol: &str,
    ) -> Vec<ServiceProvider> {
        let guard = self.inner.read().await;
        let mut binding_priorities: Vec<(String, i64)> = guard
            .file
            .bindings
            .iter()
            .filter(|b| b.service_id == service_id)
            .map(|b| (b.provider_name.clone(), b.priority))
            .collect();
        binding_priorities.sort_by_key(|(_, prio)| *prio);

        let mut result = Vec::new();
        for (name, _) in &binding_priorities {
            if let Some(p) = guard
                .file
                .providers
                .iter()
                .find(|p| {
                    p.is_enabled
                        && (protocol.is_empty() || p.provider_type == protocol)
                        && p.name == *name
                })
                .cloned()
            {
                result.push(to_service_provider(p));
            }
        }
        result
    }

    pub async fn select_provider_for_service(
        &self,
        service_id: &str,
        protocol: &str,
    ) -> Option<ServiceProvider> {
        let providers: Vec<ProviderEntry> = {
            let guard = self.inner.read().await;
            let names: Vec<String> = guard
                .file
                .bindings
                .iter()
                .filter(|b| b.service_id == service_id)
                .map(|b| b.provider_name.clone())
                .collect();
            guard
                .file
                .providers
                .iter()
                .filter(|p| {
                    p.is_enabled
                        && (protocol.is_empty() || p.provider_type == protocol)
                        && names.contains(&p.name)
                })
                .cloned()
                .collect()
        };
        if providers.is_empty() {
            return None;
        }
        let bucket = format!("{}:{}", service_id, protocol);
        let mut rr = self.service_rr.lock().await;
        let idx = rr.entry(bucket).or_insert(0);
        let p = providers[*idx % providers.len()].clone();
        *idx = (*idx + 1) % providers.len();
        Some(to_service_provider(p))
    }

    pub async fn select_provider_for_service_with_hint(
        &self,
        service_id: &str,
        protocol: &str,
        hint: &str,
    ) -> Option<ServiceProvider> {
        let needle = hint.trim().to_ascii_lowercase();
        if needle.is_empty() {
            return self.select_provider_for_service(service_id, protocol).await;
        }

        let providers: Vec<ProviderEntry> = {
            let guard = self.inner.read().await;
            let names: Vec<String> = guard
                .file
                .bindings
                .iter()
                .filter(|b| b.service_id == service_id)
                .map(|b| b.provider_name.clone())
                .collect();
            guard
                .file
                .providers
                .iter()
                .filter(|p| {
                    p.is_enabled
                        && (protocol.is_empty() || p.provider_type == protocol)
                        && names.contains(&p.name)
                })
                .cloned()
                .collect()
        };

        let p = providers.into_iter().find(|p| {
            if p.name.eq_ignore_ascii_case(&needle) {
                return true;
            }
            if p.endpoint_id.eq_ignore_ascii_case(&needle) {
                return true;
            }
            p.endpoint_id
                .split(':')
                .next()
                .map(|family| family.eq_ignore_ascii_case(&needle))
                .unwrap_or(false)
        })?;

        Some(to_service_provider(p))
    }

    pub async fn increment_used_quota(&self, key: &str) {
        let mut guard = self.inner.write().await;
        if let Some(k) = guard.file.api_keys.iter_mut().find(|a| a.key == key) {
            k.used_quota += 1;
            guard.dirty = true;
        }
    }

    pub async fn record_stat(&self, endpoint: &str, _status_code: u16, _latency_ms: i64) {
        let mut guard = self.inner.write().await;
        guard.request_stats.total += 1;
        if endpoint == "/v1/chat/completions" {
            guard.request_stats.openai_total += 1;
        } else if endpoint == "/v1/messages" {
            guard.request_stats.anthropic_total += 1;
        } else if endpoint == "/v1/embeddings" {
            guard.request_stats.embeddings_total += 1;
        }
    }

    pub async fn metrics_snapshot(&self) -> (u64, u64, u64, u64) {
        let guard = self.inner.read().await;
        (
            guard.request_stats.total,
            guard.request_stats.openai_total,
            guard.request_stats.anthropic_total,
            guard.request_stats.embeddings_total,
        )
    }
}