neuron_types/traits.rs
//! Core traits: Provider, EmbeddingProvider, Tool, ToolDyn, ContextStrategy, ObservabilityHook, DurableContext, PermissionPolicy.
2
3use std::future::Future;
4use std::time::Duration;
5
6use serde::Serialize;
7use serde::de::DeserializeOwned;
8
9use crate::error::{
10 ContextError, DurableError, EmbeddingError, HookError, ProviderError, ToolError,
11};
12use crate::stream::StreamHandle;
13use crate::types::{
14 CompletionRequest, CompletionResponse, ContentItem, EmbeddingRequest, EmbeddingResponse,
15 Message, ToolContext, ToolDefinition, ToolOutput,
16};
17use crate::wasm::{WasmBoxedFuture, WasmCompatSend, WasmCompatSync};
18
19/// LLM provider trait. Implement this for each provider (Anthropic, OpenAI, Ollama, etc.).
20///
21/// Uses RPITIT (return position impl trait in trait) — Rust 2024 native async.
22/// Not object-safe by design; use generics `<P: Provider>` to compose.
23///
24/// # Example
25///
26/// ```no_run
27/// use std::future::Future;
28/// use neuron_types::*;
29///
30/// struct MyProvider;
31///
32/// impl Provider for MyProvider {
33/// fn complete(&self, request: CompletionRequest)
34/// -> impl Future<Output = Result<CompletionResponse, ProviderError>> + Send
35/// {
36/// async { todo!() }
37/// }
38///
39/// fn complete_stream(&self, request: CompletionRequest)
40/// -> impl Future<Output = Result<StreamHandle, ProviderError>> + Send
41/// {
42/// async { todo!() }
43/// }
44/// }
45/// ```
46pub trait Provider: WasmCompatSend + WasmCompatSync {
47 /// Send a completion request and get a full response.
48 fn complete(
49 &self,
50 request: CompletionRequest,
51 ) -> impl Future<Output = Result<CompletionResponse, ProviderError>> + WasmCompatSend;
52
53 /// Send a completion request and get a stream of events.
54 fn complete_stream(
55 &self,
56 request: CompletionRequest,
57 ) -> impl Future<Output = Result<StreamHandle, ProviderError>> + WasmCompatSend;
58}
59
60/// Embedding provider trait. Implement this for providers that support text embeddings.
61///
62/// Kept separate from [`Provider`] because not all embedding models support chat completion
63/// and not all chat providers support embeddings. Implement both traits on a struct when a
64/// provider supports both capabilities.
65///
66/// Uses RPITIT (return position impl trait in trait) — Rust 2024 native async.
67/// Not object-safe by design; use generics `<E: EmbeddingProvider>` to compose.
68///
69/// # Example
70///
71/// ```no_run
72/// use std::future::Future;
73/// use neuron_types::*;
74///
75/// struct MyEmbeddingProvider;
76///
77/// impl EmbeddingProvider for MyEmbeddingProvider {
78/// fn embed(&self, request: EmbeddingRequest)
79/// -> impl Future<Output = Result<EmbeddingResponse, EmbeddingError>> + Send
80/// {
81/// async { todo!() }
82/// }
83/// }
84/// ```
85pub trait EmbeddingProvider: WasmCompatSend + WasmCompatSync {
86 /// Send an embedding request and get vectors back.
87 ///
88 /// Multiple input strings are batched into a single request. The returned
89 /// `embeddings` vec is in the same order as `request.input`.
90 fn embed(
91 &self,
92 request: EmbeddingRequest,
93 ) -> impl Future<Output = Result<EmbeddingResponse, EmbeddingError>> + WasmCompatSend;
94}
95
96/// Strongly-typed tool trait. Implement this for your tools.
97///
98/// The blanket impl of [`ToolDyn`] handles JSON deserialization/serialization
99/// so you work with concrete Rust types.
100///
101/// # Example
102///
103/// ```no_run
104/// use neuron_types::*;
105/// use serde::Deserialize;
106///
107/// #[derive(Debug, Deserialize, schemars::JsonSchema)]
108/// struct MyArgs { query: String }
109///
110/// struct MyTool;
111/// impl Tool for MyTool {
112/// const NAME: &'static str = "my_tool";
113/// type Args = MyArgs;
114/// type Output = String;
115/// type Error = std::io::Error;
116///
117/// fn definition(&self) -> ToolDefinition { todo!() }
118/// fn call(&self, args: MyArgs, ctx: &ToolContext)
119/// -> impl Future<Output = Result<String, std::io::Error>> + Send
120/// { async { Ok(args.query) } }
121/// }
122/// ```
123pub trait Tool: WasmCompatSend + WasmCompatSync {
124 /// The unique name of this tool.
125 const NAME: &'static str;
126 /// The deserialized input type.
127 type Args: DeserializeOwned + schemars::JsonSchema + WasmCompatSend;
128 /// The serializable output type.
129 type Output: Serialize;
130 /// The tool-specific error type.
131 type Error: std::error::Error + WasmCompatSend + 'static;
132
133 /// Returns the tool definition (name, description, schema).
134 fn definition(&self) -> ToolDefinition;
135
136 /// Execute the tool with typed arguments.
137 fn call(
138 &self,
139 args: Self::Args,
140 ctx: &ToolContext,
141 ) -> impl Future<Output = Result<Self::Output, Self::Error>> + WasmCompatSend;
142}
143
144/// Type-erased tool for dynamic dispatch. Blanket-implemented for all [`Tool`] impls.
145///
146/// This enables heterogeneous tool collections (`HashMap<String, Arc<dyn ToolDyn>>`)
147/// while preserving type safety at the implementation level.
148pub trait ToolDyn: WasmCompatSend + WasmCompatSync {
149 /// The tool's unique name.
150 fn name(&self) -> &str;
151 /// The tool definition (name, description, input schema).
152 fn definition(&self) -> ToolDefinition;
153 /// Execute the tool with a JSON value input, returning a generic output.
154 fn call_dyn<'a>(
155 &'a self,
156 input: serde_json::Value,
157 ctx: &'a ToolContext,
158 ) -> WasmBoxedFuture<'a, Result<ToolOutput, ToolError>>;
159}
160
161/// Blanket implementation: any `Tool` automatically becomes a `ToolDyn`.
162///
163/// Handles:
164/// - Deserializing `serde_json::Value` into `T::Args`
165/// - Calling `T::call(args, ctx)`
166/// - Serializing `T::Output` into `ToolOutput`
167/// - Mapping `T::Error` into `ToolError::ExecutionFailed`
168impl<T: Tool> ToolDyn for T {
169 fn name(&self) -> &str {
170 T::NAME
171 }
172
173 fn definition(&self) -> ToolDefinition {
174 Tool::definition(self)
175 }
176
177 fn call_dyn<'a>(
178 &'a self,
179 input: serde_json::Value,
180 ctx: &'a ToolContext,
181 ) -> WasmBoxedFuture<'a, Result<ToolOutput, ToolError>> {
182 Box::pin(async move {
183 let args: T::Args = serde_json::from_value(input)
184 .map_err(|e| ToolError::InvalidInput(e.to_string()))?;
185
186 let output = self
187 .call(args, ctx)
188 .await
189 .map_err(|e| ToolError::ExecutionFailed(e.to_string().into()))?;
190
191 let structured = serde_json::to_value(&output)
192 .map_err(|e| ToolError::ExecutionFailed(Box::new(e)))?;
193
194 let text = match &structured {
195 serde_json::Value::String(s) => s.clone(),
196 other => other.to_string(),
197 };
198
199 Ok(ToolOutput {
200 content: vec![ContentItem::Text(text)],
201 structured_content: Some(structured),
202 is_error: false,
203 })
204 })
205 }
206}
207
208// --- Context Strategy ---
209
210/// Strategy for compacting conversation context when it exceeds token limits.
211///
212/// Implementations decide when to compact and how to reduce the message list.
213///
214/// # Example
215///
216/// ```no_run
217/// use std::future::Future;
218/// use neuron_types::*;
219///
220/// struct KeepLastN { n: usize }
221/// impl ContextStrategy for KeepLastN {
222/// fn should_compact(&self, _messages: &[Message], token_count: usize) -> bool {
223/// token_count > 100_000
224/// }
225/// fn compact(&self, messages: Vec<Message>)
226/// -> impl Future<Output = Result<Vec<Message>, ContextError>> + Send
227/// {
228/// async move { Ok(messages.into_iter().rev().take(self.n).collect()) }
229/// }
230/// fn token_estimate(&self, messages: &[Message]) -> usize { messages.len() * 100 }
231/// }
232/// ```
233pub trait ContextStrategy: WasmCompatSend + WasmCompatSync {
234 /// Whether compaction should be triggered given the current messages and token count.
235 fn should_compact(&self, messages: &[Message], token_count: usize) -> bool;
236
237 /// Compact the message list to reduce token usage.
238 fn compact(
239 &self,
240 messages: Vec<Message>,
241 ) -> impl Future<Output = Result<Vec<Message>, ContextError>> + WasmCompatSend;
242
243 /// Estimate the token count for a list of messages.
244 fn token_estimate(&self, messages: &[Message]) -> usize;
245}
246
247// --- Observability Hooks ---
248
249/// Events fired during the agentic loop for observability.
250#[derive(Debug)]
251pub enum HookEvent<'a> {
252 /// Start of a loop iteration.
253 LoopIteration {
254 /// The current turn number (0-indexed).
255 turn: usize,
256 },
257 /// Before calling the LLM provider.
258 PreLlmCall {
259 /// The request about to be sent.
260 request: &'a CompletionRequest,
261 },
262 /// After receiving the LLM response.
263 PostLlmCall {
264 /// The response received.
265 response: &'a CompletionResponse,
266 },
267 /// Before executing a tool.
268 PreToolExecution {
269 /// Name of the tool.
270 tool_name: &'a str,
271 /// Input arguments.
272 input: &'a serde_json::Value,
273 },
274 /// After executing a tool.
275 PostToolExecution {
276 /// Name of the tool.
277 tool_name: &'a str,
278 /// The tool's output.
279 output: &'a ToolOutput,
280 },
281 /// Context was compacted.
282 ContextCompaction {
283 /// Token count before compaction.
284 old_tokens: usize,
285 /// Token count after compaction.
286 new_tokens: usize,
287 },
288 /// A session started.
289 SessionStart {
290 /// The session identifier.
291 session_id: &'a str,
292 },
293 /// A session ended.
294 SessionEnd {
295 /// The session identifier.
296 session_id: &'a str,
297 },
298}
299
/// Action to take after processing a hook event.
///
/// Derives `Clone`, `PartialEq`, and `Eq` so hook results can be stored, compared
/// directly (e.g. `action == HookAction::Continue`), and asserted on in tests.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum HookAction {
    /// Continue normal execution.
    Continue,
    /// Skip the current operation and return a rejection message.
    Skip {
        /// Reason for skipping.
        reason: String,
    },
    /// Terminate the loop immediately.
    Terminate {
        /// Reason for termination.
        reason: String,
    },
}
316
317/// Hook for observability (logging, metrics, telemetry).
318///
319/// Does NOT control execution flow beyond Continue/Skip/Terminate.
320/// For durable execution wrapping, use [`DurableContext`] instead.
321///
322/// # Example
323///
324/// ```no_run
325/// use std::future::Future;
326/// use neuron_types::*;
327///
328/// struct LogHook;
329/// impl ObservabilityHook for LogHook {
330/// fn on_event(&self, event: HookEvent<'_>)
331/// -> impl Future<Output = Result<HookAction, HookError>> + Send
332/// {
333/// async move { println!("{event:?}"); Ok(HookAction::Continue) }
334/// }
335/// }
336/// ```
337pub trait ObservabilityHook: WasmCompatSend + WasmCompatSync {
338 /// Called for each event in the agentic loop.
339 fn on_event(
340 &self,
341 event: HookEvent<'_>,
342 ) -> impl Future<Output = Result<HookAction, HookError>> + WasmCompatSend;
343}
344
345// --- Durable Context ---
346
347/// Options for executing an activity in a durable context.
348#[derive(Debug, Clone)]
349pub struct ActivityOptions {
350 /// Maximum time for the activity to complete.
351 pub start_to_close_timeout: Duration,
352 /// Heartbeat interval for long-running activities.
353 pub heartbeat_timeout: Option<Duration>,
354 /// Retry policy for failed activities.
355 pub retry_policy: Option<RetryPolicy>,
356}
357
/// Backoff-and-retry configuration for a durable activity.
#[derive(Clone, Debug)]
pub struct RetryPolicy {
    /// Delay before the first retry attempt.
    pub initial_interval: Duration,
    /// Multiplier used for exponential backoff.
    pub backoff_coefficient: f64,
    /// Upper bound on the number of retry attempts.
    pub maximum_attempts: u32,
    /// Ceiling on the delay between retries.
    pub maximum_interval: Duration,
    /// Error types (by name) that must not be retried.
    pub non_retryable_errors: Vec<String>,
}
372
373/// Wraps side effects for durable execution engines (Temporal, Restate, Inngest).
374///
375/// When present, the agentic loop calls through this instead of directly
376/// calling provider/tools, enabling journaling, replay, and crash recovery.
377pub trait DurableContext: WasmCompatSend + WasmCompatSync {
378 /// Execute an LLM call as a durable activity.
379 fn execute_llm_call(
380 &self,
381 request: CompletionRequest,
382 options: ActivityOptions,
383 ) -> impl Future<Output = Result<CompletionResponse, DurableError>> + WasmCompatSend;
384
385 /// Execute a tool call as a durable activity.
386 fn execute_tool(
387 &self,
388 tool_name: &str,
389 input: serde_json::Value,
390 ctx: &ToolContext,
391 options: ActivityOptions,
392 ) -> impl Future<Output = Result<ToolOutput, DurableError>> + WasmCompatSend;
393
394 /// Wait for an external signal with a timeout.
395 fn wait_for_signal<T: DeserializeOwned + WasmCompatSend>(
396 &self,
397 signal_name: &str,
398 timeout: Duration,
399 ) -> impl Future<Output = Result<Option<T>, DurableError>> + WasmCompatSend;
400
401 /// Whether the workflow should continue-as-new to avoid history bloat.
402 fn should_continue_as_new(&self) -> bool;
403
404 /// Continue the workflow as a new execution with the given state.
405 fn continue_as_new(
406 &self,
407 state: serde_json::Value,
408 ) -> impl Future<Output = Result<(), DurableError>> + WasmCompatSend;
409
410 /// Sleep for a duration (durable — survives replay).
411 fn sleep(&self, duration: Duration) -> impl Future<Output = ()> + WasmCompatSend;
412
413 /// Current time (deterministic during replay).
414 fn now(&self) -> chrono::DateTime<chrono::Utc>;
415}
416
417// --- Permission Policy ---
418
/// Decision returned by a permission check.
///
/// Derives `PartialEq` and `Eq` so decisions can be compared directly
/// (e.g. `decision == PermissionDecision::Allow`) instead of through `matches!`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum PermissionDecision {
    /// Allow the tool call.
    Allow,
    /// Deny the tool call; the string carries the reason.
    Deny(String),
    /// Ask the user for confirmation, with an accompanying message
    /// (presumably shown to the user; confirm against callers).
    Ask(String),
}
429
430/// Policy for checking tool call permissions.
431///
432/// # Example
433///
434/// ```no_run
435/// use neuron_types::*;
436///
437/// struct AllowAll;
438/// impl PermissionPolicy for AllowAll {
439/// fn check(&self, _tool_name: &str, _input: &serde_json::Value) -> PermissionDecision {
440/// PermissionDecision::Allow
441/// }
442/// }
443/// ```
444pub trait PermissionPolicy: WasmCompatSend + WasmCompatSync {
445 /// Check whether a tool call is permitted.
446 fn check(&self, tool_name: &str, input: &serde_json::Value) -> PermissionDecision;
447}