awaken-server 0.6.0

Multi-protocol HTTP server with SSE, mailbox, and protocol adapters for Awaken
Documentation
use std::collections::HashMap;
use std::sync::Arc;

use awaken_runtime::registry::memory::{
    MapAgentSpecRegistry, MapModelRegistry, MapProviderRegistry,
};
use awaken_runtime::registry::{
    AgentSpecRegistry, ModelCapabilityPatch, RegistrySet, ToolRegistry,
};
use awaken_server_contract::{AgentSpec, ModelPoolSpec, ModelSpec, ProviderSpec};
use sha2::{Digest, Sha256};

use super::{
    AgentSpecRegistryWithDiscovery, ConfigRuntimeError, ConfigRuntimeManager, ProviderExecutorCache,
};

/// Inputs consumed by `ConfigRuntimeManager::compile_registry_set`.
///
/// Slice and map fields are borrowed for `'a`; the optional registries are
/// shared handles that are moved into the compiled result.
pub(super) struct RegistryCompileInput<'a> {
    /// Provider specs; each one is registered in the provider registry and
    /// recorded, together with its executor, in the returned executor cache.
    pub providers: &'a [ProviderSpec],
    /// Model specs registered in the model registry.
    pub models: &'a [ModelSpec],
    /// Model pool specs, registered after all models.
    pub pools: &'a [ModelPoolSpec],
    /// Agent specs registered in the local agent registry.
    pub agents: &'a [AgentSpec],
    /// Tool specs whose descriptions are compared against the manager's live
    /// tools; a differing description becomes an override for that tool id.
    pub tool_specs: &'a [awaken_server_contract::ToolSpec],
    /// Optional tool registry forwarded to `compose_tool_registry`.
    pub dynamic_tools: Option<Arc<dyn ToolRegistry>>,
    /// Optional registry paired with the local agents through
    /// `AgentSpecRegistryWithDiscovery`; when `None`, the manager's own
    /// `discovered_agents` is used instead.
    pub discovered_agents: Option<Arc<dyn AgentSpecRegistry>>,
    /// Capability patches looked up by provider id; the inner map's keys are
    /// presumably model identifiers — TODO confirm against the registry API.
    pub provider_capabilities: &'a HashMap<String, HashMap<String, ModelCapabilityPatch>>,
}

impl ConfigRuntimeManager {
    pub(super) fn compile_registry_set(
        &self,
        input: RegistryCompileInput<'_>,
    ) -> Result<(RegistrySet, ProviderExecutorCache), ConfigRuntimeError> {
        let RegistryCompileInput {
            providers,
            models,
            pools,
            agents,
            tool_specs,
            dynamic_tools,
            discovered_agents,
            provider_capabilities,
        } = input;
        let mut provider_registry = MapProviderRegistry::new();
        let mut next_cache: ProviderExecutorCache = HashMap::with_capacity(providers.len());
        let prior_cache = self.provider_cache.lock().executor_snapshot();
        for provider in providers {
            let executor = match prior_cache.get(&provider.id) {
                Some((cached_spec, cached_executor)) if cached_spec == provider => {
                    Arc::clone(cached_executor)
                }
                _ => self.provider_factory.build(provider)?,
            };
            next_cache.insert(
                provider.id.clone(),
                (provider.clone(), Arc::clone(&executor)),
            );
            provider_registry
                .register_provider_with_signature_and_capability_source(
                    provider.id.clone(),
                    executor,
                    provider_definition_signature(provider),
                    provider.adapter.clone(),
                )
                .map_err(|error| ConfigRuntimeError::InvalidConfig(error.to_string()))?;
            if let Some(capabilities) = provider_capabilities.get(&provider.id) {
                provider_registry.replace_provider_model_capability_snapshot(
                    provider.id.clone(),
                    capabilities.clone(),
                );
            }
        }

        let mut model_registry = MapModelRegistry::new();
        for model in models {
            model_registry
                .register_model(model.clone())
                .map_err(|error| ConfigRuntimeError::InvalidConfig(error.to_string()))?;
        }
        for pool in pools {
            model_registry
                .register_model_pool(pool.clone())
                .map_err(|error| ConfigRuntimeError::InvalidConfig(error.to_string()))?;
        }

        let mut local_agents = MapAgentSpecRegistry::new();
        for agent in agents {
            local_agents
                .register_spec(agent.clone())
                .map_err(|error| ConfigRuntimeError::InvalidConfig(error.to_string()))?;
        }

        let local_agents: Arc<dyn AgentSpecRegistry> = Arc::new(local_agents);
        let discovered_agents = discovered_agents.or_else(|| self.discovered_agents.clone());
        let agents = match discovered_agents {
            Some(fallback) => Arc::new(AgentSpecRegistryWithDiscovery::new(local_agents, fallback))
                as Arc<dyn AgentSpecRegistry>,
            None => local_agents,
        };

        let overrides: HashMap<String, String> = tool_specs
            .iter()
            .filter_map(|spec| {
                let live = self.tools.get_tool(&spec.id)?;
                if live.descriptor().description != spec.description {
                    Some((spec.id.clone(), spec.description.clone()))
                } else {
                    None
                }
            })
            .collect();
        let tools = self.compose_tool_registry(dynamic_tools, overrides)?;

        Ok((
            RegistrySet {
                agents,
                tools,
                models: Arc::new(model_registry),
                providers: Arc::new(provider_registry),
                plugins: Arc::clone(&self.plugins),
                backends: Arc::clone(&self.backends),
            },
            next_cache,
        ))
    }
}

/// Renders a provider definition as a single signature string.
///
/// The string combines the adapter, base URL, timeout, the credential
/// signature from `provider_credential_signature`, and the adapter options
/// serialized as JSON.
pub(super) fn provider_definition_signature(provider: &ProviderSpec) -> String {
    let options = match serde_json::to_string(&provider.adapter_options) {
        Ok(json) => json,
        // Fall back to a fixed placeholder if the options cannot be serialized.
        Err(_) => String::from("<options>"),
    };
    let credential = provider_credential_signature(provider);

    format!(
        "adapter={};base_url={:?};timeout={};credential={};options={}",
        provider.adapter, provider.base_url, provider.timeout_secs, credential, options
    )
}

/// Describes a provider's credential as `kind=...;material=...`.
///
/// The kind is the string value of the `credentials_kind` adapter option,
/// defaulting to `bearer`. The material is the SHA-256 digest of a non-empty
/// API key in hex, or `none` when the key is absent or empty.
fn provider_credential_signature(provider: &ProviderSpec) -> String {
    let kind = match provider.adapter_options.get("credentials_kind") {
        // A non-string option value also falls back to the default kind.
        Some(value) => value.as_str().unwrap_or("bearer"),
        None => "bearer",
    };

    let fingerprint = match provider.api_key.as_ref() {
        Some(key) if !key.is_empty() => {
            let digest = Sha256::digest(key.expose_secret().as_bytes());
            format!("sha256:{digest:x}")
        }
        _ => String::from("none"),
    };

    format!("kind={kind};material={fingerprint}")
}

#[cfg(test)]
mod tests {
    use std::collections::BTreeMap;

    use awaken_server_contract::ProviderSpec;

    use super::provider_definition_signature;

    /// Builds a fixed provider spec that varies only in its API key.
    fn spec_with_key(api_key: Option<&str>) -> ProviderSpec {
        ProviderSpec {
            id: "provider-a".into(),
            adapter: "openai".into(),
            api_key: api_key.map(Into::into),
            base_url: Some("https://example.invalid/v1".into()),
            timeout_secs: 30,
            adapter_options: BTreeMap::new(),
        }
    }

    /// Different keys must yield different signatures, and neither signature
    /// may contain its key verbatim.
    #[test]
    fn provider_signature_changes_when_credential_material_changes() {
        let keys = ["credential-one", "credential-two"];
        let signatures: Vec<String> = keys
            .iter()
            .map(|key| provider_definition_signature(&spec_with_key(Some(*key))))
            .collect();

        assert_ne!(signatures[0], signatures[1]);
        for (signature, key) in signatures.iter().zip(keys.iter()) {
            assert!(!signature.contains(*key));
        }
    }
}