Skip to main content

linguasteg_core/
gateway.rs

1use crate::{CoreError, CoreResult, LanguageTag, ModelId, ProviderId, StrategyId};
2
/// The kind of work a [`GatewayRequest`] asks a model gateway to perform.
///
/// This is a fieldless enum, so it is `Copy` (like [`GatewayMessageRole`]) and
/// `Hash`, which lets it be passed by value and used as a map/set key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum GatewayOperation {
    /// The `Encode` operation.
    Encode,
    /// The `Decode` operation.
    Decode,
    /// The `Analyze` operation.
    Analyze,
}
9
/// Author role attached to each [`GatewayMessage`] in a request's conversation.
///
/// Derives `Hash` in addition to `Eq` so roles can key maps and sets.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum GatewayMessageRole {
    /// The `System` role (conventionally used for instructions; gateway-defined).
    System,
    /// The `User` role (conventionally the caller's input; gateway-defined).
    User,
    /// The `Assistant` role (conventionally prior model output; gateway-defined).
    Assistant,
}
16
17#[derive(Debug, Clone, PartialEq, Eq)]
18pub struct GatewayMessage {
19    pub role: GatewayMessageRole,
20    pub content: String,
21}
22
23#[derive(Debug, Clone, PartialEq, Eq)]
24pub struct GatewayRequest {
25    pub provider: ProviderId,
26    pub model: ModelId,
27    pub language: LanguageTag,
28    pub strategy: StrategyId,
29    pub operation: GatewayOperation,
30    pub messages: Vec<GatewayMessage>,
31    pub seed: Option<u64>,
32    pub max_tokens: Option<u32>,
33}
34
/// Why a gateway stopped producing output for a [`GatewayResponse`].
///
/// Derives `Hash` (every payload is hashable) so reasons can be tallied in maps.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum GatewayFinishReason {
    /// The gateway reported a normal stop.
    Stop,
    /// The gateway reported hitting a length limit.
    Length,
    /// Any other reason; presumably the provider's raw reason string — confirm in gateway impls.
    Unknown(String),
}
41
/// Token accounting optionally attached to a [`GatewayResponse`].
///
/// The fields are public and can be set independently, so `total_tokens` is not
/// guaranteed to equal `prompt_tokens + completion_tokens`. Use
/// [`GatewayUsage::new`] to build a value whose total is derived consistently.
#[derive(Debug, Clone, Default, PartialEq, Eq, Hash)]
pub struct GatewayUsage {
    /// Tokens attributed to the prompt.
    pub prompt_tokens: u32,
    /// Tokens attributed to the completion.
    pub completion_tokens: u32,
    /// Total tokens as reported or computed.
    pub total_tokens: u32,
}

impl GatewayUsage {
    /// Builds usage from prompt and completion counts, deriving `total_tokens`.
    ///
    /// The sum saturates at `u32::MAX` instead of overflowing. A plain `+` would
    /// panic in debug builds and silently wrap in release builds.
    pub fn new(prompt_tokens: u32, completion_tokens: u32) -> Self {
        Self {
            prompt_tokens,
            completion_tokens,
            total_tokens: prompt_tokens.saturating_add(completion_tokens),
        }
    }
}
48
49#[derive(Debug, Clone, PartialEq, Eq)]
50pub struct GatewayResponse {
51    pub content: String,
52    pub finish_reason: GatewayFinishReason,
53    pub usage: Option<GatewayUsage>,
54}
55
56pub trait ModelGateway: Send + Sync {
57    fn provider(&self) -> &ProviderId;
58    fn complete(&self, request: GatewayRequest) -> CoreResult<GatewayResponse>;
59}
60
61pub trait ModelGatewayRegistry: Send + Sync {
62    fn gateway(&self, provider: &ProviderId) -> Option<&dyn ModelGateway>;
63
64    fn route(&self, provider: &ProviderId, model: &ModelId) -> CoreResult<&dyn ModelGateway> {
65        self.gateway(provider)
66            .ok_or_else(|| CoreError::UnsupportedModel {
67                provider: provider.to_string(),
68                model: model.to_string(),
69            })
70    }
71}
72
#[cfg(test)]
mod tests {
    use crate::{LanguageTag, ModelId, ProviderId, StrategyId};

    use super::{
        GatewayFinishReason, GatewayMessage, GatewayMessageRole, GatewayOperation, GatewayRequest,
        GatewayResponse, ModelGateway, ModelGatewayRegistry,
    };

    /// Stub gateway that answers with `"<operation>:<message count>"` instead of calling a model.
    struct TestGateway {
        provider: ProviderId,
    }

    impl ModelGateway for TestGateway {
        fn provider(&self) -> &ProviderId {
            &self.provider
        }

        fn complete(&self, request: GatewayRequest) -> crate::CoreResult<GatewayResponse> {
            Ok(GatewayResponse {
                content: format!(
                    "{}:{}",
                    request.operation_name_for_test(),
                    request.messages.len()
                ),
                finish_reason: GatewayFinishReason::Stop,
                usage: None,
            })
        }
    }

    /// Registry that finds gateways by linear scan over an in-memory list.
    struct TestRegistry {
        gateways: Vec<TestGateway>,
    }

    impl ModelGatewayRegistry for TestRegistry {
        fn gateway(&self, provider: &ProviderId) -> Option<&dyn ModelGateway> {
            self.gateways
                .iter()
                .find(|gateway| gateway.provider() == provider)
                .map(|gateway| gateway as &dyn ModelGateway)
        }
    }

    impl GatewayRequest {
        /// Lowercase label for the request's operation, used in the stub's output.
        fn operation_name_for_test(&self) -> &'static str {
            match self.operation {
                GatewayOperation::Encode => "encode",
                GatewayOperation::Decode => "decode",
                GatewayOperation::Analyze => "analyze",
            }
        }
    }

    #[test]
    fn registry_routes_to_registered_provider() {
        let provider = ProviderId::new("stub").expect("valid provider id");
        let registry = TestRegistry {
            gateways: vec![TestGateway {
                provider: provider.clone(),
            }],
        };
        let model = ModelId::new("model-a").expect("valid model id");

        let gateway = registry
            .route(&provider, &model)
            .expect("gateway should be found");
        let response = gateway
            .complete(sample_request(provider, model, GatewayOperation::Encode))
            .expect("completion should succeed");

        assert_eq!(response.content, "encode:1");
    }

    #[test]
    fn registry_selects_matching_provider_among_several() {
        let alpha = ProviderId::new("alpha").expect("valid provider id");
        let beta = ProviderId::new("beta").expect("valid provider id");
        let registry = TestRegistry {
            gateways: vec![
                TestGateway {
                    provider: alpha.clone(),
                },
                TestGateway {
                    provider: beta.clone(),
                },
            ],
        };
        let model = ModelId::new("model-a").expect("valid model id");

        // With only one registered gateway, a lookup that ignored `provider`
        // would still pass; two gateways prove the match is actually performed.
        let gateway = registry
            .route(&beta, &model)
            .expect("beta gateway should be found");
        assert_eq!(gateway.provider(), &beta);

        let gateway = registry
            .route(&alpha, &model)
            .expect("alpha gateway should be found");
        assert_eq!(gateway.provider(), &alpha);
    }

    #[test]
    fn gateway_reports_each_operation() {
        let provider = ProviderId::new("stub").expect("valid provider id");
        let model = ModelId::new("model-a").expect("valid model id");
        let gateway = TestGateway {
            provider: provider.clone(),
        };

        let cases = [
            (GatewayOperation::Encode, "encode:1"),
            (GatewayOperation::Decode, "decode:1"),
            (GatewayOperation::Analyze, "analyze:1"),
        ];
        for (operation, expected) in cases {
            let response = gateway
                .complete(sample_request(provider.clone(), model.clone(), operation))
                .expect("completion should succeed");
            assert_eq!(response.content, expected);
        }
    }

    #[test]
    fn registry_rejects_unknown_provider() {
        let registry = TestRegistry {
            gateways: Vec::new(),
        };
        let provider = ProviderId::new("missing").expect("valid provider id");
        let model = ModelId::new("model-a").expect("valid model id");

        // `&dyn ModelGateway` is not `Debug`, so `expect_err` is unavailable here.
        let result = registry.route(&provider, &model);
        match result {
            Ok(_) => panic!("route should fail"),
            Err(error) => assert!(error.to_string().contains("model is not supported")),
        }
    }

    /// Builds a request with one user message and fixed language/strategy/seed/limit values.
    fn sample_request(
        provider: ProviderId,
        model: ModelId,
        operation: GatewayOperation,
    ) -> GatewayRequest {
        GatewayRequest {
            provider,
            model,
            language: LanguageTag::new("fa").expect("valid language tag"),
            strategy: StrategyId::new("symbolic").expect("valid strategy"),
            operation,
            messages: vec![GatewayMessage {
                role: GatewayMessageRole::User,
                content: "sample".to_string(),
            }],
            seed: Some(7),
            max_tokens: Some(32),
        }
    }
}