Skip to main content

nemo_flow_adaptive/
plugin_component.rs

1// SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4//! Core plugin integration for the adaptive runtime.
5
6use std::future::Future;
7use std::pin::Pin;
8use std::sync::{Arc, Mutex};
9
10use nemo_flow::plugin::{
11    ConfigDiagnostic, ConfigPolicy, DiagnosticLevel, Plugin, PluginComponentSpec, PluginError,
12    PluginRegistration, PluginRegistrationContext, Result, UnsupportedBehavior, deregister_plugin,
13    lookup_plugin, register_plugin,
14};
15use serde_json::{Map, Value as Json};
16
17use crate::config::AdaptiveConfig;
18use crate::error::AdaptiveError;
19use crate::runtime::features::AdaptiveRuntime;
20
/// The plugin kind registered by the adaptive crate.
///
/// Within this file the same string is used as the `kind` of every
/// `PluginComponentSpec` converted from a [`ComponentSpec`], as the value
/// reported by the adaptive plugin's `plugin_kind`, and as the key passed to
/// the core registry when looking up or deregistering the plugin.
pub const ADAPTIVE_PLUGIN_KIND: &str = "adaptive";
23
24/// One configured adaptive component.
25#[derive(Debug, Clone)]
26pub struct ComponentSpec {
27    /// Whether the adaptive component should be activated.
28    pub enabled: bool,
29    /// Adaptive config for this top-level component.
30    pub config: AdaptiveConfig,
31}
32
33impl ComponentSpec {
34    /// Creates an enabled adaptive component spec.
35    pub fn new(config: AdaptiveConfig) -> Self {
36        Self {
37            enabled: true,
38            config,
39        }
40    }
41}
42
43impl From<ComponentSpec> for PluginComponentSpec {
44    fn from(value: ComponentSpec) -> Self {
45        let Json::Object(config) =
46            serde_json::to_value(value.config).expect("adaptive config should serialize to object")
47        else {
48            unreachable!("adaptive config must serialize to object");
49        };
50
51        PluginComponentSpec {
52            kind: ADAPTIVE_PLUGIN_KIND.to_string(),
53            enabled: value.enabled,
54            config,
55        }
56    }
57}
58
59struct AdaptivePlugin;
60
61impl Plugin for AdaptivePlugin {
62    fn plugin_kind(&self) -> &str {
63        ADAPTIVE_PLUGIN_KIND
64    }
65
66    fn allows_multiple_components(&self) -> bool {
67        false
68    }
69
70    fn validate(&self, plugin_config: &Map<String, Json>) -> Vec<ConfigDiagnostic> {
71        validate_adaptive_plugin_config(plugin_config)
72    }
73
74    fn register<'a>(
75        &'a self,
76        plugin_config: &Map<String, Json>,
77        ctx: &'a mut PluginRegistrationContext,
78    ) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'a>> {
79        let plugin_config = plugin_config.clone();
80        Box::pin(async move {
81            let config = parse_adaptive_config(&plugin_config)?;
82            let mut runtime = AdaptiveRuntime::new(config)
83                .await
84                .map_err(adaptive_to_plugin_error)?;
85            runtime.register().await.map_err(adaptive_to_plugin_error)?;
86
87            let runtime = Arc::new(Mutex::new(Some(runtime)));
88            ctx.add_registration(PluginRegistration::new(
89                ADAPTIVE_PLUGIN_KIND,
90                ADAPTIVE_PLUGIN_KIND,
91                Box::new(move || {
92                    let mut guard = runtime.lock().map_err(|err| {
93                        PluginError::Internal(format!(
94                            "adaptive runtime registration lock poisoned: {err}"
95                        ))
96                    })?;
97                    if let Some(mut runtime) = guard.take() {
98                        runtime.deregister().map_err(adaptive_to_plugin_error)?;
99                    }
100                    Ok(())
101                }),
102            ));
103            Ok(())
104        })
105    }
106}
107
108/// Registers the adaptive component kind in the core plugin registry.
109///
110/// Call this during startup before validating or initializing plugin configs
111/// that contain adaptive components.
112///
113/// # Returns
114/// A core plugin [`Result`] that is `Ok(())` when the adaptive component kind
115/// is available in the registry.
116///
117/// # Errors
118/// Returns an error when registration fails for a reason other than an already
119/// registered adaptive component.
120///
121/// # Notes
122/// Re-registering the adaptive component is treated as success when the
123/// existing registration already resolves to the adaptive plugin kind.
124pub fn register_adaptive_component() -> Result<()> {
125    match register_plugin(Arc::new(AdaptivePlugin)) {
126        Ok(()) => Ok(()),
127        Err(PluginError::RegistrationFailed(message))
128            if message.contains("already registered")
129                && lookup_plugin(ADAPTIVE_PLUGIN_KIND).is_some() =>
130        {
131            Ok(())
132        }
133        Err(err) => Err(err),
134    }
135}
136
/// Deregisters the adaptive component kind from the core plugin registry.
///
/// Only the registry entry for [`ADAPTIVE_PLUGIN_KIND`] is removed, so this
/// affects future validation and initialization only.
///
/// # Returns
/// `true` when the adaptive component kind was removed from the registry and
/// `false` when it was not registered.
///
/// # Notes
/// Active adaptive runtime registrations are not torn down by this function;
/// they remain until cleared or replaced.
pub fn deregister_adaptive_component() -> bool {
    deregister_plugin(ADAPTIVE_PLUGIN_KIND)
}
151
152fn parse_adaptive_config(plugin_config: &Map<String, Json>) -> Result<AdaptiveConfig> {
153    serde_json::from_value(Json::Object(plugin_config.clone()))
154        .map_err(|err| PluginError::InvalidConfig(format!("invalid adaptive plugin config: {err}")))
155}
156
157fn validate_adaptive_plugin_config(plugin_config: &Map<String, Json>) -> Vec<ConfigDiagnostic> {
158    let config = match parse_adaptive_config(plugin_config) {
159        Ok(config) => config,
160        Err(err) => {
161            return vec![ConfigDiagnostic {
162                level: DiagnosticLevel::Error,
163                code: "adaptive.invalid_plugin_config".to_string(),
164                component: Some(ADAPTIVE_PLUGIN_KIND.to_string()),
165                field: None,
166                message: err.to_string(),
167            }];
168        }
169    };
170
171    let mut diagnostics = vec![];
172    validate_unknown_fields(
173        &mut diagnostics,
174        &config.policy,
175        Some(ADAPTIVE_PLUGIN_KIND.to_string()),
176        plugin_config,
177        &[
178            "version",
179            "agent_id",
180            "state",
181            "telemetry",
182            "adaptive_hints",
183            "tool_parallelism",
184            "acg",
185            "policy",
186        ],
187    );
188
189    if let Some(policy_json) = plugin_config.get("policy").and_then(Json::as_object) {
190        validate_unknown_fields(
191            &mut diagnostics,
192            &config.policy,
193            Some("policy".to_string()),
194            policy_json,
195            &["unknown_component", "unknown_field", "unsupported_value"],
196        );
197    }
198
199    if let Some(state_json) = plugin_config.get("state").and_then(Json::as_object) {
200        validate_unknown_fields(
201            &mut diagnostics,
202            &config.policy,
203            Some("state".to_string()),
204            state_json,
205            &["backend"],
206        );
207        if let Some(backend_json) = state_json.get("backend").and_then(Json::as_object) {
208            validate_unknown_fields(
209                &mut diagnostics,
210                &config.policy,
211                Some("backend".to_string()),
212                backend_json,
213                &["kind", "config"],
214            );
215            let backend_kind = backend_json
216                .get("kind")
217                .and_then(Json::as_str)
218                .unwrap_or_default();
219            if let Some(backend_config_json) = backend_json.get("config").and_then(Json::as_object)
220            {
221                validate_backend_config_fields(
222                    &mut diagnostics,
223                    &config.policy,
224                    backend_kind,
225                    backend_config_json,
226                );
227            }
228        }
229    }
230
231    if let Some(telemetry_json) = plugin_config.get("telemetry").and_then(Json::as_object) {
232        validate_unknown_fields(
233            &mut diagnostics,
234            &config.policy,
235            Some("telemetry".to_string()),
236            telemetry_json,
237            &["subscriber_name", "learners"],
238        );
239    }
240
241    if let Some(adaptive_hints_json) = plugin_config
242        .get("adaptive_hints")
243        .and_then(Json::as_object)
244    {
245        validate_unknown_fields(
246            &mut diagnostics,
247            &config.policy,
248            Some("adaptive_hints".to_string()),
249            adaptive_hints_json,
250            &[
251                "priority",
252                "break_chain",
253                "inject_header",
254                "inject_body_path",
255            ],
256        );
257    }
258
259    if let Some(tool_parallelism_json) = plugin_config
260        .get("tool_parallelism")
261        .and_then(Json::as_object)
262    {
263        validate_unknown_fields(
264            &mut diagnostics,
265            &config.policy,
266            Some("tool_parallelism".to_string()),
267            tool_parallelism_json,
268            &["priority", "mode"],
269        );
270    }
271
272    if let Some(acg_json) = plugin_config.get("acg").and_then(Json::as_object) {
273        validate_unknown_fields(
274            &mut diagnostics,
275            &config.policy,
276            Some("acg".to_string()),
277            acg_json,
278            &[
279                "provider",
280                "observation_window",
281                "priority",
282                "stability_thresholds",
283            ],
284        );
285    }
286
287    diagnostics.extend(AdaptiveRuntime::validate_config(&config).diagnostics);
288    diagnostics
289}
290
291fn validate_backend_config_fields(
292    diagnostics: &mut Vec<ConfigDiagnostic>,
293    policy: &ConfigPolicy,
294    backend_kind: &str,
295    backend_config: &Map<String, Json>,
296) {
297    let known_fields: &[&str] = match backend_kind {
298        "in_memory" => &[],
299        "redis" => &["url", "key_prefix"],
300        _ => return,
301    };
302    validate_unknown_fields(
303        diagnostics,
304        policy,
305        Some(backend_kind.to_string()),
306        backend_config,
307        known_fields,
308    );
309}
310
311fn validate_unknown_fields(
312    diagnostics: &mut Vec<ConfigDiagnostic>,
313    policy: &ConfigPolicy,
314    component: Option<String>,
315    config: &Map<String, Json>,
316    known_fields: &[&str],
317) {
318    for field in config.keys() {
319        if !known_fields.contains(&field.as_str()) {
320            push_policy_diag(
321                diagnostics,
322                policy.unknown_field,
323                "adaptive.unknown_field",
324                component.clone(),
325                Some(field.clone()),
326                format!(
327                    "field '{}' is not recognized for '{}'",
328                    field,
329                    component.as_deref().unwrap_or("unknown")
330                ),
331            );
332        }
333    }
334}
335
336fn push_policy_diag(
337    diagnostics: &mut Vec<ConfigDiagnostic>,
338    behavior: UnsupportedBehavior,
339    code: &str,
340    component: Option<String>,
341    field: Option<String>,
342    message: String,
343) {
344    let level = match behavior {
345        UnsupportedBehavior::Ignore => return,
346        UnsupportedBehavior::Warn => DiagnosticLevel::Warning,
347        UnsupportedBehavior::Error => DiagnosticLevel::Error,
348    };
349
350    diagnostics.push(ConfigDiagnostic {
351        level,
352        code: code.to_string(),
353        component,
354        field,
355        message,
356    });
357}
358
359fn adaptive_to_plugin_error(err: AdaptiveError) -> PluginError {
360    match err {
361        AdaptiveError::InvalidConfig(message) => PluginError::InvalidConfig(message),
362        AdaptiveError::NotFound(message) => PluginError::NotFound(message),
363        AdaptiveError::Storage(message) => PluginError::Internal(message),
364        AdaptiveError::Serialization(err) => PluginError::Serialization(err),
365        AdaptiveError::Internal(message) => PluginError::Internal(message),
366        AdaptiveError::RegistrationFailed(message) => PluginError::RegistrationFailed(message),
367        AdaptiveError::ChannelClosed(message) => PluginError::Internal(message),
368        #[cfg(feature = "redis-backend")]
369        AdaptiveError::Redis(err) => PluginError::Internal(err.to_string()),
370    }
371}
372
// Unit tests for this module live in a separate file (see the `#[path]`
// attribute) and are compiled only for test builds.
#[cfg(test)]
#[path = "../tests/unit/plugin_component_tests.rs"]
mod tests;