lash_core/provider/
resolver.rs1use std::collections::BTreeMap;
2
3use super::ProviderHandle;
4
5#[derive(Clone, Debug, Default, PartialEq, Eq)]
6pub struct ProviderBinding {
7 pub provider_id: String,
8 pub provider: ProviderHandle,
9}
10
11impl ProviderBinding {
12 pub fn new(
13 provider_id: impl Into<String>,
14 provider: ProviderHandle,
15 ) -> Result<Self, ProviderResolutionError> {
16 let provider_id = provider_id.into();
17 let requested = provider_id.trim();
18 if requested.is_empty() {
19 return Err(ProviderResolutionError::MissingProviderId);
20 }
21 let actual = provider.kind();
22 if actual != requested {
23 return Err(ProviderResolutionError::ProviderIdMismatch {
24 expected: requested.to_string(),
25 actual: actual.to_string(),
26 });
27 }
28 Ok(Self {
29 provider_id: requested.to_string(),
30 provider,
31 })
32 }
33
34 pub fn from_provider(provider: ProviderHandle) -> Self {
35 Self {
36 provider_id: provider.kind().to_string(),
37 provider,
38 }
39 }
40}
41
42#[derive(Clone, Debug, thiserror::Error, PartialEq, Eq)]
43pub enum ProviderResolutionError {
44 #[error("session policy does not specify provider_id")]
45 MissingProviderId,
46 #[error("provider `{provider_id}` is not registered with the runtime host")]
47 UnknownProvider { provider_id: String },
48 #[error("provider resolver returned `{actual}` for requested provider `{expected}`")]
49 ProviderIdMismatch { expected: String, actual: String },
50}
51
52pub trait RuntimeProviderResolver: Send + Sync {
53 fn resolve_provider_binding(
54 &self,
55 provider_id: &str,
56 ) -> Result<ProviderBinding, ProviderResolutionError>;
57}
58
/// A resolver with no providers: every lookup fails.
#[derive(Debug, Default, Clone)]
pub struct EmptyProviderResolver;
61
62impl RuntimeProviderResolver for EmptyProviderResolver {
63 fn resolve_provider_binding(
64 &self,
65 provider_id: &str,
66 ) -> Result<ProviderBinding, ProviderResolutionError> {
67 let provider_id = provider_id.trim();
68 if provider_id.is_empty() {
69 return Err(ProviderResolutionError::MissingProviderId);
70 }
71 Err(ProviderResolutionError::UnknownProvider {
72 provider_id: provider_id.to_string(),
73 })
74 }
75}
76
77#[derive(Clone, Debug)]
78pub struct SingleProviderResolver {
79 provider_id: String,
80 provider: ProviderHandle,
81}
82
83impl SingleProviderResolver {
84 pub fn new(provider: ProviderHandle) -> Self {
85 Self {
86 provider_id: provider.kind().to_string(),
87 provider,
88 }
89 }
90
91 pub fn with_provider_id(provider_id: impl Into<String>, provider: ProviderHandle) -> Self {
92 Self {
93 provider_id: provider_id.into(),
94 provider,
95 }
96 }
97}
98
99impl RuntimeProviderResolver for SingleProviderResolver {
100 fn resolve_provider_binding(
101 &self,
102 provider_id: &str,
103 ) -> Result<ProviderBinding, ProviderResolutionError> {
104 let requested = provider_id.trim();
105 if requested.is_empty() {
106 return Err(ProviderResolutionError::MissingProviderId);
107 }
108 if requested != self.provider_id {
109 return Err(ProviderResolutionError::UnknownProvider {
110 provider_id: requested.to_string(),
111 });
112 }
113 ProviderBinding::new(requested, self.provider.clone())
114 }
115}
116
117#[derive(Clone, Debug, Default)]
118pub struct MapProviderResolver {
119 providers: BTreeMap<String, ProviderHandle>,
120}
121
122impl MapProviderResolver {
123 pub fn new() -> Self {
124 Self::default()
125 }
126
127 pub fn with_provider(mut self, provider: ProviderHandle) -> Self {
128 self.providers.insert(provider.kind().to_string(), provider);
129 self
130 }
131
132 pub fn with_provider_id(
133 mut self,
134 provider_id: impl Into<String>,
135 provider: ProviderHandle,
136 ) -> Self {
137 self.providers.insert(provider_id.into(), provider);
138 self
139 }
140}
141
142impl RuntimeProviderResolver for MapProviderResolver {
143 fn resolve_provider_binding(
144 &self,
145 provider_id: &str,
146 ) -> Result<ProviderBinding, ProviderResolutionError> {
147 let requested = provider_id.trim();
148 if requested.is_empty() {
149 return Err(ProviderResolutionError::MissingProviderId);
150 }
151 let provider = self.providers.get(requested).ok_or_else(|| {
152 ProviderResolutionError::UnknownProvider {
153 provider_id: requested.to_string(),
154 }
155 })?;
156 ProviderBinding::new(requested, provider.clone())
157 }
158}
159
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a test provider whose `kind()` is `kind`.
    fn provider(kind: &'static str) -> ProviderHandle {
        crate::testing::TestProvider::builder()
            .kind(kind)
            .build()
            .into_handle()
    }

    #[test]
    fn map_provider_resolver_returns_registered_provider() {
        let resolver = MapProviderResolver::new().with_provider(provider("mock"));

        let resolved = resolver
            .resolve_provider_binding("mock")
            .expect("registered provider resolves");

        assert_eq!(resolved.provider_id, "mock");
        assert_eq!(resolved.provider.kind(), "mock");
    }

    #[test]
    fn map_provider_resolver_trims_requested_provider_id() {
        let resolver = MapProviderResolver::new().with_provider(provider("mock"));

        let resolved = resolver
            .resolve_provider_binding("  mock\t")
            .expect("padded request resolves to the registered provider");

        assert_eq!(resolved.provider_id, "mock");
        assert_eq!(resolved.provider.kind(), "mock");
    }

    #[test]
    fn map_provider_resolver_reports_missing_provider_id() {
        let resolver = MapProviderResolver::new().with_provider(provider("mock"));

        let err = resolver
            .resolve_provider_binding(" ")
            .expect_err("empty provider id is rejected");

        assert_eq!(err, ProviderResolutionError::MissingProviderId);
    }

    #[test]
    fn map_provider_resolver_reports_unknown_provider_id() {
        let resolver = MapProviderResolver::new().with_provider(provider("mock"));

        let err = resolver
            .resolve_provider_binding("other")
            .expect_err("unknown provider id is rejected");

        assert_eq!(
            err,
            ProviderResolutionError::UnknownProvider {
                provider_id: "other".to_string()
            }
        );
    }

    #[test]
    fn map_provider_resolver_reports_provider_id_mismatch() {
        let resolver = MapProviderResolver::new().with_provider_id("recorded", provider("actual"));

        let err = resolver
            .resolve_provider_binding("recorded")
            .expect_err("mismatched live provider is rejected");

        assert_eq!(
            err,
            ProviderResolutionError::ProviderIdMismatch {
                expected: "recorded".to_string(),
                actual: "actual".to_string()
            }
        );
    }

    #[test]
    fn single_provider_resolver_returns_its_provider() {
        let resolver = SingleProviderResolver::new(provider("mock"));

        let resolved = resolver
            .resolve_provider_binding(" mock ")
            .expect("configured provider resolves");

        assert_eq!(resolved.provider_id, "mock");
        assert_eq!(resolved.provider.kind(), "mock");
    }

    #[test]
    fn single_provider_resolver_reports_missing_and_unknown_ids() {
        let resolver = SingleProviderResolver::new(provider("mock"));

        let missing = resolver
            .resolve_provider_binding("")
            .expect_err("empty provider id is rejected");
        assert_eq!(missing, ProviderResolutionError::MissingProviderId);

        let unknown = resolver
            .resolve_provider_binding("other")
            .expect_err("other provider id is rejected");
        assert_eq!(
            unknown,
            ProviderResolutionError::UnknownProvider {
                provider_id: "other".to_string()
            }
        );
    }

    #[test]
    fn single_provider_resolver_reports_provider_id_mismatch() {
        let resolver = SingleProviderResolver::with_provider_id("recorded", provider("actual"));

        let err = resolver
            .resolve_provider_binding("recorded")
            .expect_err("mismatched live provider is rejected");

        assert_eq!(
            err,
            ProviderResolutionError::ProviderIdMismatch {
                expected: "recorded".to_string(),
                actual: "actual".to_string()
            }
        );
    }

    #[test]
    fn empty_provider_resolver_rejects_every_id() {
        let resolver = EmptyProviderResolver;

        let missing = resolver
            .resolve_provider_binding(" ")
            .expect_err("blank provider id is rejected");
        assert_eq!(missing, ProviderResolutionError::MissingProviderId);

        let unknown = resolver
            .resolve_provider_binding(" mock ")
            .expect_err("no provider is ever registered");
        assert_eq!(
            unknown,
            ProviderResolutionError::UnknownProvider {
                provider_id: "mock".to_string()
            }
        );
    }

    #[test]
    fn provider_binding_new_trims_and_validates_id() {
        let binding =
            ProviderBinding::new(" mock ", provider("mock")).expect("matching id binds");
        assert_eq!(binding.provider_id, "mock");

        let missing = ProviderBinding::new("\n", provider("mock"))
            .expect_err("blank provider id is rejected");
        assert_eq!(missing, ProviderResolutionError::MissingProviderId);

        let mismatch = ProviderBinding::new("other", provider("mock"))
            .expect_err("mismatched provider id is rejected");
        assert_eq!(
            mismatch,
            ProviderResolutionError::ProviderIdMismatch {
                expected: "other".to_string(),
                actual: "mock".to_string()
            }
        );
    }

    #[test]
    fn provider_binding_from_provider_uses_kind() {
        let binding = ProviderBinding::from_provider(provider("mock"));

        assert_eq!(binding.provider_id, "mock");
        assert_eq!(binding.provider.kind(), "mock");
    }
}