1use crate::{CoreError, CoreResult, LanguageTag, ModelId, ProviderId, StrategyId};
2
/// The kind of work a [`GatewayRequest`] asks a model gateway to perform.
///
/// This is a field-less enum, so it derives `Copy` (as [`GatewayMessageRole`]
/// does) and can be passed and matched by value without cloning.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum GatewayOperation {
    /// Encode operation.
    Encode,
    /// Decode operation.
    Decode,
    /// Analyze operation.
    Analyze,
}
9
/// Who authored a single [`GatewayMessage`] in the conversation sent to a gateway.
///
/// Cheap to copy; compare variants with `==`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum GatewayMessageRole {
    /// Message in the `System` role.
    System,
    /// Message in the `User` role.
    User,
    /// Message in the `Assistant` role.
    Assistant,
}
16
17#[derive(Debug, Clone, PartialEq, Eq)]
18pub struct GatewayMessage {
19 pub role: GatewayMessageRole,
20 pub content: String,
21}
22
23#[derive(Debug, Clone, PartialEq, Eq)]
24pub struct GatewayRequest {
25 pub provider: ProviderId,
26 pub model: ModelId,
27 pub language: LanguageTag,
28 pub strategy: StrategyId,
29 pub operation: GatewayOperation,
30 pub messages: Vec<GatewayMessage>,
31 pub seed: Option<u64>,
32 pub max_tokens: Option<u32>,
33}
34
/// Reason reported for why a completion finished.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum GatewayFinishReason {
    /// Generation stopped normally.
    Stop,
    /// Generation ended on a length limit (presumably the token limit).
    Length,
    /// Any other reason; presumably carries the provider's raw reason text.
    Unknown(String),
}
41
/// Token counts reported for a completion.
///
/// This type does not validate the fields against each other. In particular,
/// `total_tokens` is stored as given and is not checked against
/// `prompt_tokens + completion_tokens`.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct GatewayUsage {
    /// Tokens attributed to the prompt.
    pub prompt_tokens: u32,
    /// Tokens attributed to the completion.
    pub completion_tokens: u32,
    /// Total tokens as reported.
    pub total_tokens: u32,
}
48
49#[derive(Debug, Clone, PartialEq, Eq)]
50pub struct GatewayResponse {
51 pub content: String,
52 pub finish_reason: GatewayFinishReason,
53 pub usage: Option<GatewayUsage>,
54}
55
56pub trait ModelGateway: Send + Sync {
57 fn provider(&self) -> &ProviderId;
58 fn complete(&self, request: GatewayRequest) -> CoreResult<GatewayResponse>;
59}
60
61pub trait ModelGatewayRegistry: Send + Sync {
62 fn gateway(&self, provider: &ProviderId) -> Option<&dyn ModelGateway>;
63
64 fn route(&self, provider: &ProviderId, model: &ModelId) -> CoreResult<&dyn ModelGateway> {
65 self.gateway(provider)
66 .ok_or_else(|| CoreError::UnsupportedModel {
67 provider: provider.to_string(),
68 model: model.to_string(),
69 })
70 }
71}
72
#[cfg(test)]
mod tests {
    use crate::{CoreError, LanguageTag, ModelId, ProviderId, StrategyId};

    use super::{
        GatewayFinishReason, GatewayMessage, GatewayMessageRole, GatewayOperation, GatewayRequest,
        GatewayResponse, ModelGateway, ModelGatewayRegistry,
    };

    /// Stub gateway whose response content is `"<operation>:<message count>"`,
    /// so tests can see which request reached it.
    struct TestGateway {
        provider: ProviderId,
    }

    impl ModelGateway for TestGateway {
        fn provider(&self) -> &ProviderId {
            &self.provider
        }

        fn complete(&self, request: GatewayRequest) -> crate::CoreResult<GatewayResponse> {
            Ok(GatewayResponse {
                content: format!(
                    "{}:{}",
                    request.operation_name_for_test(),
                    request.messages.len()
                ),
                finish_reason: GatewayFinishReason::Stop,
                usage: None,
            })
        }
    }

    /// Registry that linearly scans its gateways for a matching provider.
    struct TestRegistry {
        gateways: Vec<TestGateway>,
    }

    impl ModelGatewayRegistry for TestRegistry {
        fn gateway(&self, provider: &ProviderId) -> Option<&dyn ModelGateway> {
            self.gateways
                .iter()
                .find(|gateway| gateway.provider() == provider)
                .map(|gateway| gateway as &dyn ModelGateway)
        }
    }

    impl GatewayRequest {
        /// Lower-case operation name used to build `TestGateway` output.
        fn operation_name_for_test(&self) -> &'static str {
            match self.operation {
                GatewayOperation::Encode => "encode",
                GatewayOperation::Decode => "decode",
                GatewayOperation::Analyze => "analyze",
            }
        }
    }

    #[test]
    fn registry_routes_to_registered_provider() {
        let provider = ProviderId::new("stub").expect("valid provider id");
        let registry = TestRegistry {
            gateways: vec![TestGateway {
                provider: provider.clone(),
            }],
        };
        let model = ModelId::new("model-a").expect("valid model id");

        let gateway = registry
            .route(&provider, &model)
            .expect("gateway should be found");
        let response = gateway
            .complete(sample_request(provider, model, GatewayOperation::Encode))
            .expect("completion should succeed");

        assert_eq!(response.content, "encode:1");
    }

    #[test]
    fn registry_routes_to_matching_provider_among_several() {
        let first = ProviderId::new("first").expect("valid provider id");
        let second = ProviderId::new("second").expect("valid provider id");
        let registry = TestRegistry {
            gateways: vec![
                TestGateway {
                    provider: first.clone(),
                },
                TestGateway {
                    provider: second.clone(),
                },
            ],
        };
        let model = ModelId::new("model-a").expect("valid model id");

        let gateway = registry
            .route(&second, &model)
            .expect("gateway should be found");
        // The lookup must pick the gateway for `second`, not simply the first entry.
        assert_eq!(gateway.provider(), &second);

        let response = gateway
            .complete(sample_request(second, model, GatewayOperation::Decode))
            .expect("completion should succeed");
        assert_eq!(response.content, "decode:1");
        assert_eq!(response.finish_reason, GatewayFinishReason::Stop);
    }

    #[test]
    fn registry_rejects_unknown_provider() {
        let registry = TestRegistry {
            gateways: Vec::new(),
        };
        let provider = ProviderId::new("missing").expect("valid provider id");
        let model = ModelId::new("model-a").expect("valid model id");

        // `&dyn ModelGateway` is not `Debug`, so `expect_err` is unavailable here.
        let error = match registry.route(&provider, &model) {
            Ok(_) => panic!("route should fail for an unregistered provider"),
            Err(error) => error,
        };
        // Pin the variant, not just a substring of the rendered message.
        assert!(
            matches!(error, CoreError::UnsupportedModel { .. }),
            "unexpected error: {}",
            error
        );
        assert!(error.to_string().contains("model is not supported"));
    }

    /// Builds a request with a single user message and fixed language/strategy.
    fn sample_request(
        provider: ProviderId,
        model: ModelId,
        operation: GatewayOperation,
    ) -> GatewayRequest {
        GatewayRequest {
            provider,
            model,
            language: LanguageTag::new("fa").expect("valid language tag"),
            strategy: StrategyId::new("symbolic").expect("valid strategy"),
            operation,
            messages: vec![GatewayMessage {
                role: GatewayMessageRole::User,
                content: "sample".to_string(),
            }],
            seed: Some(7),
            max_tokens: Some(32),
        }
    }
}