// lash_core/provider/resolver.rs

use std::collections::BTreeMap;

use super::ProviderHandle;

5#[derive(Clone, Debug, Default, PartialEq, Eq)]
6pub struct ProviderBinding {
7    pub provider_id: String,
8    pub provider: ProviderHandle,
9}
10
11impl ProviderBinding {
12    pub fn new(
13        provider_id: impl Into<String>,
14        provider: ProviderHandle,
15    ) -> Result<Self, ProviderResolutionError> {
16        let provider_id = provider_id.into();
17        let requested = provider_id.trim();
18        if requested.is_empty() {
19            return Err(ProviderResolutionError::MissingProviderId);
20        }
21        let actual = provider.kind();
22        if actual != requested {
23            return Err(ProviderResolutionError::ProviderIdMismatch {
24                expected: requested.to_string(),
25                actual: actual.to_string(),
26            });
27        }
28        Ok(Self {
29            provider_id: requested.to_string(),
30            provider,
31        })
32    }
33
34    pub fn from_provider(provider: ProviderHandle) -> Self {
35        Self {
36            provider_id: provider.kind().to_string(),
37            provider,
38        }
39    }
40}
41
42#[derive(Clone, Debug, thiserror::Error, PartialEq, Eq)]
43pub enum ProviderResolutionError {
44    #[error("session policy does not specify provider_id")]
45    MissingProviderId,
46    #[error("provider `{provider_id}` is not registered with the runtime host")]
47    UnknownProvider { provider_id: String },
48    #[error("provider resolver returned `{actual}` for requested provider `{expected}`")]
49    ProviderIdMismatch { expected: String, actual: String },
50}
51
52pub trait RuntimeProviderResolver: Send + Sync {
53    fn resolve_provider_binding(
54        &self,
55        provider_id: &str,
56    ) -> Result<ProviderBinding, ProviderResolutionError>;
57}
58
59#[derive(Clone, Debug, Default)]
60pub struct EmptyProviderResolver;
61
62impl RuntimeProviderResolver for EmptyProviderResolver {
63    fn resolve_provider_binding(
64        &self,
65        provider_id: &str,
66    ) -> Result<ProviderBinding, ProviderResolutionError> {
67        let provider_id = provider_id.trim();
68        if provider_id.is_empty() {
69            return Err(ProviderResolutionError::MissingProviderId);
70        }
71        Err(ProviderResolutionError::UnknownProvider {
72            provider_id: provider_id.to_string(),
73        })
74    }
75}
76
77#[derive(Clone, Debug)]
78pub struct SingleProviderResolver {
79    provider_id: String,
80    provider: ProviderHandle,
81}
82
83impl SingleProviderResolver {
84    pub fn new(provider: ProviderHandle) -> Self {
85        Self {
86            provider_id: provider.kind().to_string(),
87            provider,
88        }
89    }
90
91    pub fn with_provider_id(provider_id: impl Into<String>, provider: ProviderHandle) -> Self {
92        Self {
93            provider_id: provider_id.into(),
94            provider,
95        }
96    }
97}
98
99impl RuntimeProviderResolver for SingleProviderResolver {
100    fn resolve_provider_binding(
101        &self,
102        provider_id: &str,
103    ) -> Result<ProviderBinding, ProviderResolutionError> {
104        let requested = provider_id.trim();
105        if requested.is_empty() {
106            return Err(ProviderResolutionError::MissingProviderId);
107        }
108        if requested != self.provider_id {
109            return Err(ProviderResolutionError::UnknownProvider {
110                provider_id: requested.to_string(),
111            });
112        }
113        ProviderBinding::new(requested, self.provider.clone())
114    }
115}
116
117#[derive(Clone, Debug, Default)]
118pub struct MapProviderResolver {
119    providers: BTreeMap<String, ProviderHandle>,
120}
121
122impl MapProviderResolver {
123    pub fn new() -> Self {
124        Self::default()
125    }
126
127    pub fn with_provider(mut self, provider: ProviderHandle) -> Self {
128        self.providers.insert(provider.kind().to_string(), provider);
129        self
130    }
131
132    pub fn with_provider_id(
133        mut self,
134        provider_id: impl Into<String>,
135        provider: ProviderHandle,
136    ) -> Self {
137        self.providers.insert(provider_id.into(), provider);
138        self
139    }
140}
141
142impl RuntimeProviderResolver for MapProviderResolver {
143    fn resolve_provider_binding(
144        &self,
145        provider_id: &str,
146    ) -> Result<ProviderBinding, ProviderResolutionError> {
147        let requested = provider_id.trim();
148        if requested.is_empty() {
149            return Err(ProviderResolutionError::MissingProviderId);
150        }
151        let provider = self.providers.get(requested).ok_or_else(|| {
152            ProviderResolutionError::UnknownProvider {
153                provider_id: requested.to_string(),
154            }
155        })?;
156        ProviderBinding::new(requested, provider.clone())
157    }
158}
159
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a test provider handle whose `kind()` is `kind`.
    fn provider(kind: &'static str) -> ProviderHandle {
        crate::testing::TestProvider::builder()
            .kind(kind)
            .build()
            .into_handle()
    }

    #[test]
    fn map_provider_resolver_returns_registered_provider() {
        let resolver = MapProviderResolver::new().with_provider(provider("mock"));

        let resolved = resolver
            .resolve_provider_binding("mock")
            .expect("registered provider resolves");

        assert_eq!(resolved.provider_id, "mock");
        assert_eq!(resolved.provider.kind(), "mock");
    }

    #[test]
    fn map_provider_resolver_trims_requested_provider_id() {
        let resolver = MapProviderResolver::new().with_provider(provider("mock"));

        let resolved = resolver
            .resolve_provider_binding("  mock\t")
            .expect("surrounding whitespace in the request is ignored");

        assert_eq!(resolved.provider_id, "mock");
    }

    #[test]
    fn map_provider_resolver_reports_missing_provider_id() {
        let resolver = MapProviderResolver::new().with_provider(provider("mock"));

        let err = resolver
            .resolve_provider_binding("  ")
            .expect_err("empty provider id is rejected");

        assert_eq!(err, ProviderResolutionError::MissingProviderId);
    }

    #[test]
    fn map_provider_resolver_reports_unknown_provider_id() {
        let resolver = MapProviderResolver::new().with_provider(provider("mock"));

        let err = resolver
            .resolve_provider_binding("other")
            .expect_err("unknown provider id is rejected");

        assert_eq!(
            err,
            ProviderResolutionError::UnknownProvider {
                provider_id: "other".to_string()
            }
        );
    }

    #[test]
    fn map_provider_resolver_reports_provider_id_mismatch() {
        let resolver = MapProviderResolver::new().with_provider_id("recorded", provider("actual"));

        let err = resolver
            .resolve_provider_binding("recorded")
            .expect_err("mismatched live provider is rejected");

        assert_eq!(
            err,
            ProviderResolutionError::ProviderIdMismatch {
                expected: "recorded".to_string(),
                actual: "actual".to_string()
            }
        );
    }

    #[test]
    fn single_provider_resolver_returns_configured_provider() {
        let resolver = SingleProviderResolver::new(provider("mock"));

        let resolved = resolver
            .resolve_provider_binding(" mock ")
            .expect("configured provider resolves");

        assert_eq!(resolved.provider_id, "mock");
        assert_eq!(resolved.provider.kind(), "mock");
    }

    #[test]
    fn single_provider_resolver_reports_missing_and_unknown_provider_id() {
        let resolver = SingleProviderResolver::new(provider("mock"));

        let missing = resolver
            .resolve_provider_binding("   ")
            .expect_err("empty provider id is rejected");
        assert_eq!(missing, ProviderResolutionError::MissingProviderId);

        let unknown = resolver
            .resolve_provider_binding("other")
            .expect_err("unknown provider id is rejected");
        assert_eq!(
            unknown,
            ProviderResolutionError::UnknownProvider {
                provider_id: "other".to_string()
            }
        );
    }

    #[test]
    fn empty_provider_resolver_rejects_every_provider_id() {
        let resolver = EmptyProviderResolver;

        let missing = resolver
            .resolve_provider_binding(" ")
            .expect_err("empty provider id is rejected");
        assert_eq!(missing, ProviderResolutionError::MissingProviderId);

        let unknown = resolver
            .resolve_provider_binding(" mock ")
            .expect_err("no provider is ever registered");
        assert_eq!(
            unknown,
            ProviderResolutionError::UnknownProvider {
                provider_id: "mock".to_string()
            }
        );
    }

    #[test]
    fn provider_binding_new_trims_and_checks_kind() {
        let binding =
            ProviderBinding::new(" mock ", provider("mock")).expect("matching kind binds");
        assert_eq!(binding.provider_id, "mock");

        let err = ProviderBinding::new("other", provider("mock"))
            .expect_err("kind mismatch is rejected");
        assert_eq!(
            err,
            ProviderResolutionError::ProviderIdMismatch {
                expected: "other".to_string(),
                actual: "mock".to_string()
            }
        );
    }
}