//! linguasteg-core 0.2.0 documentation
//!
//! Core domain contracts and pipeline abstractions for LinguaSteg.
use crate::{CoreError, CoreResult, LanguageTag, ModelId, ProviderId, StrategyId};

/// Kind of operation a [`GatewayRequest`] asks a gateway to perform.
///
/// The enum carries no data, so it derives `Copy` (matching
/// [`GatewayMessageRole`]); values can be passed around without `.clone()`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum GatewayOperation {
    /// Encode operation.
    Encode,
    /// Decode operation.
    Decode,
    /// Analyze operation.
    Analyze,
}

/// Role attached to a [`GatewayMessage`].
///
/// NOTE(review): the variants look like the usual chat-completion roles
/// (system prompt / user turn / assistant turn) — confirm against the
/// provider adapters that consume them.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum GatewayMessageRole {
    /// System role.
    System,
    /// User role.
    User,
    /// Assistant role.
    Assistant,
}

/// A single message carried in [`GatewayRequest::messages`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct GatewayMessage {
    /// Role the message is attributed to.
    pub role: GatewayMessageRole,
    /// Text content of the message.
    pub content: String,
}

/// Request passed by value to [`ModelGateway::complete`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct GatewayRequest {
    /// Identifier of the provider the request is addressed to.
    pub provider: ProviderId,
    /// Identifier of the model the request is addressed to.
    pub model: ModelId,
    /// Language tag associated with the request.
    pub language: LanguageTag,
    /// Strategy identifier associated with the request.
    pub strategy: StrategyId,
    /// Operation the gateway is asked to perform.
    pub operation: GatewayOperation,
    /// Messages sent to the gateway, in the order stored in the vector.
    pub messages: Vec<GatewayMessage>,
    /// Optional seed.
    /// NOTE(review): presumably for reproducible sampling — confirm with gateway implementations.
    pub seed: Option<u64>,
    /// Optional token limit.
    /// NOTE(review): presumably caps completion length — confirm with gateway implementations.
    pub max_tokens: Option<u32>,
}

/// Reason reported in [`GatewayResponse::finish_reason`] for why a completion ended.
///
/// NOTE(review): the exact mapping from provider-specific values to these
/// variants is defined by gateway implementations, not by this file.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum GatewayFinishReason {
    /// Stop reason.
    Stop,
    /// Length reason.
    /// NOTE(review): presumably a token-limit cutoff — confirm.
    Length,
    /// Any other reason, carried as a string.
    Unknown(String),
}

/// Token counters optionally attached to a [`GatewayResponse`].
///
/// NOTE(review): nothing in this type enforces
/// `total_tokens == prompt_tokens + completion_tokens`; confirm whether
/// producers are expected to uphold that.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct GatewayUsage {
    /// Prompt token count.
    pub prompt_tokens: u32,
    /// Completion token count.
    pub completion_tokens: u32,
    /// Total token count.
    pub total_tokens: u32,
}

/// Successful result of [`ModelGateway::complete`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct GatewayResponse {
    /// Text content returned by the gateway.
    pub content: String,
    /// Why the completion ended.
    pub finish_reason: GatewayFinishReason,
    /// Token counters, when the gateway supplies them.
    pub usage: Option<GatewayUsage>,
}

/// Contract for a gateway that serves completions for one provider.
///
/// Implementors must be `Send + Sync`.
pub trait ModelGateway: Send + Sync {
    /// Returns the identifier of the provider this gateway represents.
    fn provider(&self) -> &ProviderId;

    /// Processes `request` (taken by value) and returns the response.
    ///
    /// # Errors
    ///
    /// Returns a [`CoreError`] through [`CoreResult`]; which variants are
    /// produced is up to the implementation.
    fn complete(&self, request: GatewayRequest) -> CoreResult<GatewayResponse>;
}

/// Lookup table from provider identifiers to [`ModelGateway`] instances.
///
/// Implementors must be `Send + Sync`.
pub trait ModelGatewayRegistry: Send + Sync {
    /// Returns the gateway registered for `provider`, or `None` when there is none.
    fn gateway(&self, provider: &ProviderId) -> Option<&dyn ModelGateway>;

    /// Resolves the gateway for `provider`, turning a missing entry into an error.
    ///
    /// The default implementation consults only [`Self::gateway`]; `model` is
    /// not used for the lookup and only appears in the error payload.
    ///
    /// # Errors
    ///
    /// Returns [`CoreError::UnsupportedModel`] (carrying the stringified
    /// `provider` and `model`) when no gateway is registered for `provider`.
    fn route(&self, provider: &ProviderId, model: &ModelId) -> CoreResult<&dyn ModelGateway> {
        match self.gateway(provider) {
            Some(found) => Ok(found),
            None => Err(CoreError::UnsupportedModel {
                provider: provider.to_string(),
                model: model.to_string(),
            }),
        }
    }
}

#[cfg(test)]
mod tests {
    use crate::{LanguageTag, ModelId, ProviderId, StrategyId};

    use super::{
        GatewayFinishReason, GatewayMessage, GatewayMessageRole, GatewayOperation, GatewayRequest,
        GatewayResponse, ModelGateway, ModelGatewayRegistry,
    };

    /// Stub gateway whose response echoes the operation name and message count.
    struct TestGateway {
        provider: ProviderId,
    }

    impl ModelGateway for TestGateway {
        fn provider(&self) -> &ProviderId {
            &self.provider
        }

        fn complete(&self, request: GatewayRequest) -> crate::CoreResult<GatewayResponse> {
            // Content has the shape "<operation>:<number of messages>".
            let label = request.operation_name_for_test();
            let content = format!("{}:{}", label, request.messages.len());

            Ok(GatewayResponse {
                content,
                finish_reason: GatewayFinishReason::Stop,
                usage: None,
            })
        }
    }

    /// Registry backed by a plain vector, searched front to back.
    struct TestRegistry {
        gateways: Vec<TestGateway>,
    }

    impl ModelGatewayRegistry for TestRegistry {
        fn gateway(&self, provider: &ProviderId) -> Option<&dyn ModelGateway> {
            // First registered gateway with a matching provider wins.
            for candidate in &self.gateways {
                if candidate.provider() == provider {
                    return Some(candidate as &dyn ModelGateway);
                }
            }
            None
        }
    }

    impl GatewayRequest {
        /// Lower-case label for the request's operation; test-only helper.
        fn operation_name_for_test(&self) -> &'static str {
            match &self.operation {
                GatewayOperation::Analyze => "analyze",
                GatewayOperation::Decode => "decode",
                GatewayOperation::Encode => "encode",
            }
        }
    }

    #[test]
    fn registry_routes_to_registered_provider() {
        let provider_id = ProviderId::new("stub").expect("valid provider id");
        let stub = TestGateway {
            provider: provider_id.clone(),
        };
        let registry = TestRegistry {
            gateways: vec![stub],
        };
        let model_id = ModelId::new("model-a").expect("valid model id");

        let routed = registry
            .route(&provider_id, &model_id)
            .expect("gateway should be found");

        let request = sample_request(provider_id, model_id, GatewayOperation::Encode);
        let reply = routed.complete(request).expect("completion should succeed");

        assert_eq!(reply.content, "encode:1");
    }

    #[test]
    fn registry_rejects_unknown_provider() {
        let empty_registry = TestRegistry {
            gateways: Vec::new(),
        };
        let provider_id = ProviderId::new("missing").expect("valid provider id");
        let model_id = ModelId::new("model-a").expect("valid model id");

        if let Err(failure) = empty_registry.route(&provider_id, &model_id) {
            assert!(failure.to_string().contains("model is not supported"));
        } else {
            panic!("route should fail");
        }
    }

    /// Builds a request with one user message and fixed language/strategy/seed/limit.
    fn sample_request(
        provider: ProviderId,
        model: ModelId,
        operation: GatewayOperation,
    ) -> GatewayRequest {
        let language = LanguageTag::new("fa").expect("valid language tag");
        let strategy = StrategyId::new("symbolic").expect("valid strategy");
        let only_message = GatewayMessage {
            role: GatewayMessageRole::User,
            content: "sample".to_string(),
        };

        GatewayRequest {
            provider,
            model,
            language,
            strategy,
            operation,
            messages: vec![only_message],
            seed: Some(7),
            max_tokens: Some(32),
        }
    }
}