lash_core/provider/
resolver.rs1use std::collections::BTreeMap;
2
3use super::ProviderHandle;
4
5#[derive(Clone, Debug, PartialEq, Eq)]
6pub struct ProviderBinding {
7 pub provider_id: String,
8 pub provider: ProviderHandle,
9}
10
11impl Default for ProviderBinding {
12 fn default() -> Self {
13 Self {
14 provider_id: String::new(),
15 provider: ProviderHandle::unconfigured(),
16 }
17 }
18}
19
20impl ProviderBinding {
21 pub fn new(
22 provider_id: impl Into<String>,
23 provider: ProviderHandle,
24 ) -> Result<Self, ProviderResolutionError> {
25 let provider_id = provider_id.into();
26 let requested = provider_id.trim();
27 if requested.is_empty() {
28 return Err(ProviderResolutionError::MissingProviderId);
29 }
30 let actual = provider.kind();
31 if actual != requested {
32 return Err(ProviderResolutionError::ProviderIdMismatch {
33 expected: requested.to_string(),
34 actual: actual.to_string(),
35 });
36 }
37 Ok(Self {
38 provider_id: requested.to_string(),
39 provider,
40 })
41 }
42
43 pub fn from_provider(provider: ProviderHandle) -> Self {
44 Self {
45 provider_id: provider.kind().to_string(),
46 provider,
47 }
48 }
49}
50
51#[derive(Clone, Debug, thiserror::Error, PartialEq, Eq)]
52pub enum ProviderResolutionError {
53 #[error("session policy does not specify provider_id")]
54 MissingProviderId,
55 #[error("provider `{provider_id}` is not registered with the runtime host")]
56 UnknownProvider { provider_id: String },
57 #[error("provider resolver returned `{actual}` for requested provider `{expected}`")]
58 ProviderIdMismatch { expected: String, actual: String },
59}
60
61pub trait RuntimeProviderResolver: Send + Sync {
62 fn resolve_provider_binding(
63 &self,
64 provider_id: &str,
65 ) -> Result<ProviderBinding, ProviderResolutionError>;
66}
67
/// A resolver with no providers: every resolution attempt fails.
#[derive(Debug, Clone, Default)]
pub struct EmptyProviderResolver;
70
71impl RuntimeProviderResolver for EmptyProviderResolver {
72 fn resolve_provider_binding(
73 &self,
74 provider_id: &str,
75 ) -> Result<ProviderBinding, ProviderResolutionError> {
76 let provider_id = provider_id.trim();
77 if provider_id.is_empty() {
78 return Err(ProviderResolutionError::MissingProviderId);
79 }
80 Err(ProviderResolutionError::UnknownProvider {
81 provider_id: provider_id.to_string(),
82 })
83 }
84}
85
86#[derive(Clone, Debug)]
87pub struct SingleProviderResolver {
88 provider_id: String,
89 provider: ProviderHandle,
90}
91
92impl SingleProviderResolver {
93 pub fn new(provider: ProviderHandle) -> Self {
94 Self {
95 provider_id: provider.kind().to_string(),
96 provider,
97 }
98 }
99
100 pub fn with_provider_id(provider_id: impl Into<String>, provider: ProviderHandle) -> Self {
101 Self {
102 provider_id: provider_id.into(),
103 provider,
104 }
105 }
106}
107
108impl RuntimeProviderResolver for SingleProviderResolver {
109 fn resolve_provider_binding(
110 &self,
111 provider_id: &str,
112 ) -> Result<ProviderBinding, ProviderResolutionError> {
113 let requested = provider_id.trim();
114 if requested.is_empty() {
115 return Err(ProviderResolutionError::MissingProviderId);
116 }
117 if requested != self.provider_id {
118 return Err(ProviderResolutionError::UnknownProvider {
119 provider_id: requested.to_string(),
120 });
121 }
122 ProviderBinding::new(requested, self.provider.clone())
123 }
124}
125
126#[derive(Clone, Debug, Default)]
127pub struct MapProviderResolver {
128 providers: BTreeMap<String, ProviderHandle>,
129}
130
131impl MapProviderResolver {
132 pub fn new() -> Self {
133 Self::default()
134 }
135
136 pub fn with_provider(mut self, provider: ProviderHandle) -> Self {
137 self.providers.insert(provider.kind().to_string(), provider);
138 self
139 }
140
141 pub fn with_provider_id(
142 mut self,
143 provider_id: impl Into<String>,
144 provider: ProviderHandle,
145 ) -> Self {
146 self.providers.insert(provider_id.into(), provider);
147 self
148 }
149}
150
151impl RuntimeProviderResolver for MapProviderResolver {
152 fn resolve_provider_binding(
153 &self,
154 provider_id: &str,
155 ) -> Result<ProviderBinding, ProviderResolutionError> {
156 let requested = provider_id.trim();
157 if requested.is_empty() {
158 return Err(ProviderResolutionError::MissingProviderId);
159 }
160 let provider = self.providers.get(requested).ok_or_else(|| {
161 ProviderResolutionError::UnknownProvider {
162 provider_id: requested.to_string(),
163 }
164 })?;
165 ProviderBinding::new(requested, provider.clone())
166 }
167}
168
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a test provider whose `kind()` is `kind`.
    fn provider(kind: &'static str) -> ProviderHandle {
        crate::testing::TestProvider::builder()
            .kind(kind)
            .build()
            .into_handle()
    }

    #[test]
    fn map_provider_resolver_returns_registered_provider() {
        let resolver = MapProviderResolver::new().with_provider(provider("mock"));

        let resolved = resolver
            .resolve_provider_binding("mock")
            .expect("registered provider resolves");

        assert_eq!(resolved.provider_id, "mock");
        assert_eq!(resolved.provider.kind(), "mock");
    }

    #[test]
    fn map_provider_resolver_trims_requested_provider_id() {
        let resolver = MapProviderResolver::new().with_provider(provider("mock"));

        let resolved = resolver
            .resolve_provider_binding("  mock\t")
            .expect("surrounding whitespace is ignored");

        assert_eq!(resolved.provider_id, "mock");
    }

    #[test]
    fn map_provider_resolver_reports_missing_provider_id() {
        let resolver = MapProviderResolver::new().with_provider(provider("mock"));

        let err = resolver
            .resolve_provider_binding(" ")
            .expect_err("empty provider id is rejected");

        assert_eq!(err, ProviderResolutionError::MissingProviderId);
    }

    #[test]
    fn map_provider_resolver_reports_unknown_provider_id() {
        let resolver = MapProviderResolver::new().with_provider(provider("mock"));

        let err = resolver
            .resolve_provider_binding("other")
            .expect_err("unknown provider id is rejected");

        assert_eq!(
            err,
            ProviderResolutionError::UnknownProvider {
                provider_id: "other".to_string()
            }
        );
    }

    #[test]
    fn map_provider_resolver_reports_provider_id_mismatch() {
        let resolver = MapProviderResolver::new().with_provider_id("recorded", provider("actual"));

        let err = resolver
            .resolve_provider_binding("recorded")
            .expect_err("mismatched live provider is rejected");

        assert_eq!(
            err,
            ProviderResolutionError::ProviderIdMismatch {
                expected: "recorded".to_string(),
                actual: "actual".to_string()
            }
        );
    }

    #[test]
    fn single_provider_resolver_returns_its_provider() {
        let resolver = SingleProviderResolver::new(provider("mock"));

        let resolved = resolver
            .resolve_provider_binding(" mock ")
            .expect("configured provider resolves");

        assert_eq!(resolved.provider_id, "mock");
        assert_eq!(resolved.provider.kind(), "mock");
    }

    #[test]
    fn single_provider_resolver_reports_missing_and_unknown_ids() {
        let resolver = SingleProviderResolver::new(provider("mock"));

        let missing = resolver
            .resolve_provider_binding("")
            .expect_err("empty provider id is rejected");
        assert_eq!(missing, ProviderResolutionError::MissingProviderId);

        let unknown = resolver
            .resolve_provider_binding("other")
            .expect_err("other provider id is rejected");
        assert_eq!(
            unknown,
            ProviderResolutionError::UnknownProvider {
                provider_id: "other".to_string()
            }
        );
    }

    #[test]
    fn empty_provider_resolver_rejects_every_id() {
        let resolver = EmptyProviderResolver;

        let missing = resolver
            .resolve_provider_binding("  ")
            .expect_err("blank provider id is rejected");
        assert_eq!(missing, ProviderResolutionError::MissingProviderId);

        let unknown = resolver
            .resolve_provider_binding(" mock ")
            .expect_err("no provider is ever registered");
        assert_eq!(
            unknown,
            ProviderResolutionError::UnknownProvider {
                provider_id: "mock".to_string()
            }
        );
    }

    #[test]
    fn provider_binding_new_trims_and_validates_id() {
        let binding =
            ProviderBinding::new(" mock ", provider("mock")).expect("matching id binds");
        assert_eq!(binding.provider_id, "mock");

        let missing = ProviderBinding::new("", provider("mock"))
            .expect_err("empty provider id is rejected");
        assert_eq!(missing, ProviderResolutionError::MissingProviderId);

        let mismatch = ProviderBinding::new("other", provider("mock"))
            .expect_err("mismatched provider is rejected");
        assert_eq!(
            mismatch,
            ProviderResolutionError::ProviderIdMismatch {
                expected: "other".to_string(),
                actual: "mock".to_string()
            }
        );
    }

    #[test]
    fn provider_binding_from_provider_uses_kind() {
        let binding = ProviderBinding::from_provider(provider("mock"));

        assert_eq!(binding.provider_id, "mock");
        assert_eq!(binding.provider.kind(), "mock");
    }
}