1use std::future::Future;
7use std::pin::Pin;
8use std::sync::{Arc, Mutex};
9
10use nemo_flow::plugin::{
11 ConfigDiagnostic, ConfigPolicy, DiagnosticLevel, Plugin, PluginComponentSpec, PluginError,
12 PluginRegistration, PluginRegistrationContext, Result, UnsupportedBehavior, deregister_plugin,
13 lookup_plugin, register_plugin,
14};
15use serde_json::{Map, Value as Json};
16
17use crate::config::AdaptiveConfig;
18use crate::error::AdaptiveError;
19use crate::runtime::features::AdaptiveRuntime;
20
/// Kind string that identifies the adaptive plugin.
///
/// It is used as the plugin's `plugin_kind`, as the `kind` of converted component specs, and as
/// the key for plugin lookup and deregistration.
pub const ADAPTIVE_PLUGIN_KIND: &str = "adaptive";
23
24#[derive(Debug, Clone)]
26pub struct ComponentSpec {
27 pub enabled: bool,
29 pub config: AdaptiveConfig,
31}
32
33impl ComponentSpec {
34 pub fn new(config: AdaptiveConfig) -> Self {
36 Self {
37 enabled: true,
38 config,
39 }
40 }
41}
42
43impl From<ComponentSpec> for PluginComponentSpec {
44 fn from(value: ComponentSpec) -> Self {
45 let Json::Object(config) =
46 serde_json::to_value(value.config).expect("adaptive config should serialize to object")
47 else {
48 unreachable!("adaptive config must serialize to object");
49 };
50
51 PluginComponentSpec {
52 kind: ADAPTIVE_PLUGIN_KIND.to_string(),
53 enabled: value.enabled,
54 config,
55 }
56 }
57}
58
59struct AdaptivePlugin;
60
61impl Plugin for AdaptivePlugin {
62 fn plugin_kind(&self) -> &str {
63 ADAPTIVE_PLUGIN_KIND
64 }
65
66 fn allows_multiple_components(&self) -> bool {
67 false
68 }
69
70 fn validate(&self, plugin_config: &Map<String, Json>) -> Vec<ConfigDiagnostic> {
71 validate_adaptive_plugin_config(plugin_config)
72 }
73
74 fn register<'a>(
75 &'a self,
76 plugin_config: &Map<String, Json>,
77 ctx: &'a mut PluginRegistrationContext,
78 ) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'a>> {
79 let plugin_config = plugin_config.clone();
80 Box::pin(async move {
81 let config = parse_adaptive_config(&plugin_config)?;
82 let mut runtime = AdaptiveRuntime::new(config)
83 .await
84 .map_err(adaptive_to_plugin_error)?;
85 runtime.register().await.map_err(adaptive_to_plugin_error)?;
86
87 let runtime = Arc::new(Mutex::new(Some(runtime)));
88 ctx.add_registration(PluginRegistration::new(
89 ADAPTIVE_PLUGIN_KIND,
90 ADAPTIVE_PLUGIN_KIND,
91 Box::new(move || {
92 let mut guard = runtime.lock().map_err(|err| {
93 PluginError::Internal(format!(
94 "adaptive runtime registration lock poisoned: {err}"
95 ))
96 })?;
97 if let Some(mut runtime) = guard.take() {
98 runtime.deregister().map_err(adaptive_to_plugin_error)?;
99 }
100 Ok(())
101 }),
102 ));
103 Ok(())
104 })
105 }
106}
107
108pub fn register_adaptive_component() -> Result<()> {
125 match register_plugin(Arc::new(AdaptivePlugin)) {
126 Ok(()) => Ok(()),
127 Err(PluginError::RegistrationFailed(message))
128 if message.contains("already registered")
129 && lookup_plugin(ADAPTIVE_PLUGIN_KIND).is_some() =>
130 {
131 Ok(())
132 }
133 Err(err) => Err(err),
134 }
135}
136
/// Removes the adaptive plugin (kind [`ADAPTIVE_PLUGIN_KIND`]) from the global plugin registry.
///
/// Returns whatever `deregister_plugin` reports. Presumably that is `true` when a plugin was
/// actually removed; confirm against `nemo_flow::plugin::deregister_plugin`.
pub fn deregister_adaptive_component() -> bool {
    deregister_plugin(ADAPTIVE_PLUGIN_KIND)
}
151
152fn parse_adaptive_config(plugin_config: &Map<String, Json>) -> Result<AdaptiveConfig> {
153 serde_json::from_value(Json::Object(plugin_config.clone()))
154 .map_err(|err| PluginError::InvalidConfig(format!("invalid adaptive plugin config: {err}")))
155}
156
157fn validate_adaptive_plugin_config(plugin_config: &Map<String, Json>) -> Vec<ConfigDiagnostic> {
158 let config = match parse_adaptive_config(plugin_config) {
159 Ok(config) => config,
160 Err(err) => {
161 return vec![ConfigDiagnostic {
162 level: DiagnosticLevel::Error,
163 code: "adaptive.invalid_plugin_config".to_string(),
164 component: Some(ADAPTIVE_PLUGIN_KIND.to_string()),
165 field: None,
166 message: err.to_string(),
167 }];
168 }
169 };
170
171 let mut diagnostics = vec![];
172 validate_unknown_fields(
173 &mut diagnostics,
174 &config.policy,
175 Some(ADAPTIVE_PLUGIN_KIND.to_string()),
176 plugin_config,
177 &[
178 "version",
179 "agent_id",
180 "state",
181 "telemetry",
182 "adaptive_hints",
183 "tool_parallelism",
184 "acg",
185 "policy",
186 ],
187 );
188
189 if let Some(policy_json) = plugin_config.get("policy").and_then(Json::as_object) {
190 validate_unknown_fields(
191 &mut diagnostics,
192 &config.policy,
193 Some("policy".to_string()),
194 policy_json,
195 &["unknown_component", "unknown_field", "unsupported_value"],
196 );
197 }
198
199 if let Some(state_json) = plugin_config.get("state").and_then(Json::as_object) {
200 validate_unknown_fields(
201 &mut diagnostics,
202 &config.policy,
203 Some("state".to_string()),
204 state_json,
205 &["backend"],
206 );
207 if let Some(backend_json) = state_json.get("backend").and_then(Json::as_object) {
208 validate_unknown_fields(
209 &mut diagnostics,
210 &config.policy,
211 Some("backend".to_string()),
212 backend_json,
213 &["kind", "config"],
214 );
215 let backend_kind = backend_json
216 .get("kind")
217 .and_then(Json::as_str)
218 .unwrap_or_default();
219 if let Some(backend_config_json) = backend_json.get("config").and_then(Json::as_object)
220 {
221 validate_backend_config_fields(
222 &mut diagnostics,
223 &config.policy,
224 backend_kind,
225 backend_config_json,
226 );
227 }
228 }
229 }
230
231 if let Some(telemetry_json) = plugin_config.get("telemetry").and_then(Json::as_object) {
232 validate_unknown_fields(
233 &mut diagnostics,
234 &config.policy,
235 Some("telemetry".to_string()),
236 telemetry_json,
237 &["subscriber_name", "learners"],
238 );
239 }
240
241 if let Some(adaptive_hints_json) = plugin_config
242 .get("adaptive_hints")
243 .and_then(Json::as_object)
244 {
245 validate_unknown_fields(
246 &mut diagnostics,
247 &config.policy,
248 Some("adaptive_hints".to_string()),
249 adaptive_hints_json,
250 &[
251 "priority",
252 "break_chain",
253 "inject_header",
254 "inject_body_path",
255 ],
256 );
257 }
258
259 if let Some(tool_parallelism_json) = plugin_config
260 .get("tool_parallelism")
261 .and_then(Json::as_object)
262 {
263 validate_unknown_fields(
264 &mut diagnostics,
265 &config.policy,
266 Some("tool_parallelism".to_string()),
267 tool_parallelism_json,
268 &["priority", "mode"],
269 );
270 }
271
272 if let Some(acg_json) = plugin_config.get("acg").and_then(Json::as_object) {
273 validate_unknown_fields(
274 &mut diagnostics,
275 &config.policy,
276 Some("acg".to_string()),
277 acg_json,
278 &[
279 "provider",
280 "observation_window",
281 "priority",
282 "stability_thresholds",
283 ],
284 );
285 }
286
287 diagnostics.extend(AdaptiveRuntime::validate_config(&config).diagnostics);
288 diagnostics
289}
290
291fn validate_backend_config_fields(
292 diagnostics: &mut Vec<ConfigDiagnostic>,
293 policy: &ConfigPolicy,
294 backend_kind: &str,
295 backend_config: &Map<String, Json>,
296) {
297 let known_fields: &[&str] = match backend_kind {
298 "in_memory" => &[],
299 "redis" => &["url", "key_prefix"],
300 _ => return,
301 };
302 validate_unknown_fields(
303 diagnostics,
304 policy,
305 Some(backend_kind.to_string()),
306 backend_config,
307 known_fields,
308 );
309}
310
311fn validate_unknown_fields(
312 diagnostics: &mut Vec<ConfigDiagnostic>,
313 policy: &ConfigPolicy,
314 component: Option<String>,
315 config: &Map<String, Json>,
316 known_fields: &[&str],
317) {
318 for field in config.keys() {
319 if !known_fields.contains(&field.as_str()) {
320 push_policy_diag(
321 diagnostics,
322 policy.unknown_field,
323 "adaptive.unknown_field",
324 component.clone(),
325 Some(field.clone()),
326 format!(
327 "field '{}' is not recognized for '{}'",
328 field,
329 component.as_deref().unwrap_or("unknown")
330 ),
331 );
332 }
333 }
334}
335
336fn push_policy_diag(
337 diagnostics: &mut Vec<ConfigDiagnostic>,
338 behavior: UnsupportedBehavior,
339 code: &str,
340 component: Option<String>,
341 field: Option<String>,
342 message: String,
343) {
344 let level = match behavior {
345 UnsupportedBehavior::Ignore => return,
346 UnsupportedBehavior::Warn => DiagnosticLevel::Warning,
347 UnsupportedBehavior::Error => DiagnosticLevel::Error,
348 };
349
350 diagnostics.push(ConfigDiagnostic {
351 level,
352 code: code.to_string(),
353 component,
354 field,
355 message,
356 });
357}
358
359fn adaptive_to_plugin_error(err: AdaptiveError) -> PluginError {
360 match err {
361 AdaptiveError::InvalidConfig(message) => PluginError::InvalidConfig(message),
362 AdaptiveError::NotFound(message) => PluginError::NotFound(message),
363 AdaptiveError::Storage(message) => PluginError::Internal(message),
364 AdaptiveError::Serialization(err) => PluginError::Serialization(err),
365 AdaptiveError::Internal(message) => PluginError::Internal(message),
366 AdaptiveError::RegistrationFailed(message) => PluginError::RegistrationFailed(message),
367 AdaptiveError::ChannelClosed(message) => PluginError::Internal(message),
368 #[cfg(feature = "redis-backend")]
369 AdaptiveError::Redis(err) => PluginError::Internal(err.to_string()),
370 }
371}
372
// Unit tests for this module live in a separate file. `#[path]` points the test-only `tests`
// module at it.
#[cfg(test)]
#[path = "../tests/unit/plugin_component_tests.rs"]
mod tests;