// lash_core/provider/resolver.rs

1use std::collections::BTreeMap;
2
3use super::ProviderHandle;
4
5#[derive(Clone, Debug, PartialEq, Eq)]
6pub struct ProviderBinding {
7    pub provider_id: String,
8    pub provider: ProviderHandle,
9}
10
11impl Default for ProviderBinding {
12    fn default() -> Self {
13        Self {
14            provider_id: String::new(),
15            provider: ProviderHandle::unconfigured(),
16        }
17    }
18}
19
20impl ProviderBinding {
21    pub fn new(
22        provider_id: impl Into<String>,
23        provider: ProviderHandle,
24    ) -> Result<Self, ProviderResolutionError> {
25        let provider_id = provider_id.into();
26        let requested = provider_id.trim();
27        if requested.is_empty() {
28            return Err(ProviderResolutionError::MissingProviderId);
29        }
30        let actual = provider.kind();
31        if actual != requested {
32            return Err(ProviderResolutionError::ProviderIdMismatch {
33                expected: requested.to_string(),
34                actual: actual.to_string(),
35            });
36        }
37        Ok(Self {
38            provider_id: requested.to_string(),
39            provider,
40        })
41    }
42
43    pub fn from_provider(provider: ProviderHandle) -> Self {
44        Self {
45            provider_id: provider.kind().to_string(),
46            provider,
47        }
48    }
49}
50
51#[derive(Clone, Debug, thiserror::Error, PartialEq, Eq)]
52pub enum ProviderResolutionError {
53    #[error("session policy does not specify provider_id")]
54    MissingProviderId,
55    #[error("provider `{provider_id}` is not registered with the runtime host")]
56    UnknownProvider { provider_id: String },
57    #[error("provider resolver returned `{actual}` for requested provider `{expected}`")]
58    ProviderIdMismatch { expected: String, actual: String },
59}
60
61pub trait RuntimeProviderResolver: Send + Sync {
62    fn resolve_provider_binding(
63        &self,
64        provider_id: &str,
65    ) -> Result<ProviderBinding, ProviderResolutionError>;
66}
67
/// Resolver with no providers at all: every lookup is rejected.
#[derive(Debug, Default, Clone)]
pub struct EmptyProviderResolver;
70
71impl RuntimeProviderResolver for EmptyProviderResolver {
72    fn resolve_provider_binding(
73        &self,
74        provider_id: &str,
75    ) -> Result<ProviderBinding, ProviderResolutionError> {
76        let provider_id = provider_id.trim();
77        if provider_id.is_empty() {
78            return Err(ProviderResolutionError::MissingProviderId);
79        }
80        Err(ProviderResolutionError::UnknownProvider {
81            provider_id: provider_id.to_string(),
82        })
83    }
84}
85
86#[derive(Clone, Debug)]
87pub struct SingleProviderResolver {
88    provider_id: String,
89    provider: ProviderHandle,
90}
91
92impl SingleProviderResolver {
93    pub fn new(provider: ProviderHandle) -> Self {
94        Self {
95            provider_id: provider.kind().to_string(),
96            provider,
97        }
98    }
99
100    pub fn with_provider_id(provider_id: impl Into<String>, provider: ProviderHandle) -> Self {
101        Self {
102            provider_id: provider_id.into(),
103            provider,
104        }
105    }
106}
107
108impl RuntimeProviderResolver for SingleProviderResolver {
109    fn resolve_provider_binding(
110        &self,
111        provider_id: &str,
112    ) -> Result<ProviderBinding, ProviderResolutionError> {
113        let requested = provider_id.trim();
114        if requested.is_empty() {
115            return Err(ProviderResolutionError::MissingProviderId);
116        }
117        if requested != self.provider_id {
118            return Err(ProviderResolutionError::UnknownProvider {
119                provider_id: requested.to_string(),
120            });
121        }
122        ProviderBinding::new(requested, self.provider.clone())
123    }
124}
125
126#[derive(Clone, Debug, Default)]
127pub struct MapProviderResolver {
128    providers: BTreeMap<String, ProviderHandle>,
129}
130
131impl MapProviderResolver {
132    pub fn new() -> Self {
133        Self::default()
134    }
135
136    pub fn with_provider(mut self, provider: ProviderHandle) -> Self {
137        self.providers.insert(provider.kind().to_string(), provider);
138        self
139    }
140
141    pub fn with_provider_id(
142        mut self,
143        provider_id: impl Into<String>,
144        provider: ProviderHandle,
145    ) -> Self {
146        self.providers.insert(provider_id.into(), provider);
147        self
148    }
149}
150
151impl RuntimeProviderResolver for MapProviderResolver {
152    fn resolve_provider_binding(
153        &self,
154        provider_id: &str,
155    ) -> Result<ProviderBinding, ProviderResolutionError> {
156        let requested = provider_id.trim();
157        if requested.is_empty() {
158            return Err(ProviderResolutionError::MissingProviderId);
159        }
160        let provider = self.providers.get(requested).ok_or_else(|| {
161            ProviderResolutionError::UnknownProvider {
162                provider_id: requested.to_string(),
163            }
164        })?;
165        ProviderBinding::new(requested, provider.clone())
166    }
167}
168
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a test provider handle whose `kind()` is `kind`.
    fn provider(kind: &'static str) -> ProviderHandle {
        crate::testing::TestProvider::builder()
            .kind(kind)
            .build()
            .into_handle()
    }

    #[test]
    fn map_provider_resolver_returns_registered_provider() {
        let resolver = MapProviderResolver::new().with_provider(provider("mock"));

        let resolved = resolver
            .resolve_provider_binding("mock")
            .expect("registered provider resolves");

        assert_eq!(resolved.provider_id, "mock");
        assert_eq!(resolved.provider.kind(), "mock");
    }

    #[test]
    fn map_provider_resolver_trims_requested_provider_id() {
        let resolver = MapProviderResolver::new().with_provider(provider("mock"));

        let resolved = resolver
            .resolve_provider_binding("  mock\n")
            .expect("padded provider id resolves");

        assert_eq!(resolved.provider_id, "mock");
        assert_eq!(resolved.provider.kind(), "mock");
    }

    #[test]
    fn map_provider_resolver_reports_missing_provider_id() {
        let resolver = MapProviderResolver::new().with_provider(provider("mock"));

        let err = resolver
            .resolve_provider_binding("  ")
            .expect_err("empty provider id is rejected");

        assert_eq!(err, ProviderResolutionError::MissingProviderId);
    }

    #[test]
    fn map_provider_resolver_reports_unknown_provider_id() {
        let resolver = MapProviderResolver::new().with_provider(provider("mock"));

        let err = resolver
            .resolve_provider_binding("other")
            .expect_err("unknown provider id is rejected");

        assert_eq!(
            err,
            ProviderResolutionError::UnknownProvider {
                provider_id: "other".to_string()
            }
        );
    }

    #[test]
    fn map_provider_resolver_reports_provider_id_mismatch() {
        let resolver = MapProviderResolver::new().with_provider_id("recorded", provider("actual"));

        let err = resolver
            .resolve_provider_binding("recorded")
            .expect_err("mismatched live provider is rejected");

        assert_eq!(
            err,
            ProviderResolutionError::ProviderIdMismatch {
                expected: "recorded".to_string(),
                actual: "actual".to_string()
            }
        );
    }

    #[test]
    fn empty_provider_resolver_rejects_every_provider_id() {
        let resolver = EmptyProviderResolver;

        let blank = resolver
            .resolve_provider_binding(" ")
            .expect_err("blank provider id is rejected");
        assert_eq!(blank, ProviderResolutionError::MissingProviderId);

        let unknown = resolver
            .resolve_provider_binding(" mock ")
            .expect_err("nothing is registered");
        assert_eq!(
            unknown,
            ProviderResolutionError::UnknownProvider {
                provider_id: "mock".to_string()
            }
        );
    }

    #[test]
    fn single_provider_resolver_returns_its_provider() {
        let resolver = SingleProviderResolver::new(provider("mock"));

        let resolved = resolver
            .resolve_provider_binding(" mock ")
            .expect("configured provider resolves");

        assert_eq!(resolved.provider_id, "mock");
        assert_eq!(resolved.provider.kind(), "mock");
    }

    #[test]
    fn single_provider_resolver_reports_unknown_provider_id() {
        let resolver = SingleProviderResolver::new(provider("mock"));

        let err = resolver
            .resolve_provider_binding("other")
            .expect_err("only the configured provider resolves");

        assert_eq!(
            err,
            ProviderResolutionError::UnknownProvider {
                provider_id: "other".to_string()
            }
        );
    }

    #[test]
    fn provider_binding_new_trims_and_validates_provider_id() {
        let binding =
            ProviderBinding::new("  mock ", provider("mock")).expect("matching provider binds");
        assert_eq!(binding.provider_id, "mock");

        let err = ProviderBinding::new("\t", provider("mock"))
            .expect_err("blank provider id is rejected");
        assert_eq!(err, ProviderResolutionError::MissingProviderId);
    }

    #[test]
    fn provider_binding_from_provider_uses_provider_kind() {
        let binding = ProviderBinding::from_provider(provider("mock"));

        assert_eq!(binding.provider_id, "mock");
        assert_eq!(binding.provider.kind(), "mock");
    }
}