Skip to main content

pi/connectors/
mod.rs

1//! Connectors for extension hostcalls.
2//!
3//! Connectors provide capability-gated access to host resources (HTTP, filesystem, etc.)
4//! for extensions. Each connector validates requests against policy before execution.
5//!
6//! Hostcall ABI types are re-exported from `crate::extensions` so protocol
7//! serialization stays canonical across runtime and connector boundaries.
8
9pub mod http;
10
11use crate::error::Result;
12use async_trait::async_trait;
13use serde_json::{Value, json};
14
15pub use crate::extensions::{
16    HostCallError, HostCallErrorCode, HostCallPayload, HostResultPayload, HostStreamBackpressure,
17    HostStreamChunk,
18};
19
/// Trait for connectors that handle hostcalls from extensions.
///
/// A connector serves a single capability name (see [`Connector::capability`]) and
/// answers hostcalls for it via [`Connector::dispatch`]. Implementors must be
/// `Send + Sync` (supertrait bound). The `#[async_trait]` attribute is what allows
/// `dispatch` to be declared as an `async fn` inside this trait.
#[async_trait]
pub trait Connector: Send + Sync {
    /// The capability name this connector handles (e.g., "http", "fs").
    ///
    /// Returned as a `&'static str`, so the name is fixed for the lifetime of the program.
    fn capability(&self) -> &'static str;

    /// Dispatch a hostcall to this connector.
    ///
    /// Returns `HostResultPayload` with either success output or error details.
    /// The helpers [`host_result_ok`], [`host_result_err`] and
    /// [`host_result_err_with_details`] in this module build those payload shapes.
    ///
    /// # Errors
    ///
    /// Returns `Err` (of `crate::error::Result`) for failures the connector does not
    /// report in-band as an error payload. NOTE(review): this trait does not fix the
    /// split between an outer `Err` and an in-band error payload. Confirm the
    /// convention against existing implementations (e.g. the `http` module).
    async fn dispatch(&self, call: &HostCallPayload) -> Result<HostResultPayload>;
}
31
32/// Helper to create a successful host result.
33pub fn host_result_ok(call_id: &str, output: Value) -> HostResultPayload {
34    HostResultPayload {
35        call_id: call_id.to_string(),
36        output,
37        is_error: false,
38        error: None,
39        chunk: None,
40    }
41}
42
43/// Helper to create an error host result.
44pub fn host_result_err(
45    call_id: &str,
46    code: HostCallErrorCode,
47    message: impl Into<String>,
48    retryable: Option<bool>,
49) -> HostResultPayload {
50    HostResultPayload {
51        call_id: call_id.to_string(),
52        output: json!({}),
53        is_error: true,
54        error: Some(HostCallError {
55            code,
56            message: message.into(),
57            details: None,
58            retryable,
59        }),
60        chunk: None,
61    }
62}
63
64/// Helper to create an error host result with details.
65pub fn host_result_err_with_details(
66    call_id: &str,
67    code: HostCallErrorCode,
68    message: impl Into<String>,
69    details: Value,
70    retryable: Option<bool>,
71) -> HostResultPayload {
72    HostResultPayload {
73        call_id: call_id.to_string(),
74        output: json!({}),
75        is_error: true,
76        error: Some(HostCallError {
77            code,
78            message: message.into(),
79            details: Some(details),
80            retryable,
81        }),
82        chunk: None,
83    }
84}
85
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// An error built without details must still expose an object-typed `output`.
    #[test]
    fn host_result_err_output_is_object() {
        let res = host_result_err("c1", HostCallErrorCode::Io, "fail", None);
        assert!(res.is_error);

        let output = &res.output;
        assert!(
            output.is_object(),
            "error output must be object, got {:?}",
            output
        );
    }

    /// An error built with details must also expose an object-typed `output`.
    #[test]
    fn host_result_err_with_details_output_is_object() {
        let details = json!({"key": "val"});
        let res = host_result_err_with_details(
            "c2",
            HostCallErrorCode::Denied,
            "nope",
            details,
            Some(true),
        );
        assert!(res.is_error);

        let output = &res.output;
        assert!(
            output.is_object(),
            "error output must be object, got {:?}",
            output
        );
    }

    /// A success result keeps the caller's output value untouched.
    #[test]
    fn host_result_ok_output_is_preserved() {
        let expected = json!({"data": 42});
        let res = host_result_ok("c3", expected.clone());

        assert!(!res.is_error);
        assert_eq!(res.output, expected);
    }

    /// Both error helpers yield object output for each listed error code.
    #[test]
    fn all_error_codes_produce_object_output() {
        let codes = [
            HostCallErrorCode::Timeout,
            HostCallErrorCode::Denied,
            HostCallErrorCode::Io,
            HostCallErrorCode::InvalidRequest,
            HostCallErrorCode::Internal,
        ];

        for &code in &codes {
            let plain = host_result_err("c4", code, "msg", None);
            assert!(
                plain.output.is_object(),
                "code={code:?} must produce object output"
            );

            let detailed = host_result_err_with_details("c5", code, "msg", json!({}), None);
            assert!(
                detailed.output.is_object(),
                "code={code:?} with details must produce object output"
            );
        }
    }

    /// Compiles only if the names re-exported here are the very same types as
    /// those defined in `crate::extensions`.
    #[test]
    fn connectors_hostcall_types_are_canonical_extension_types() {
        fn expect_call(_: crate::extensions::HostCallPayload) {}
        fn expect_result(_: crate::extensions::HostResultPayload) {}

        expect_call(HostCallPayload {
            call_id: String::from("c6"),
            capability: String::from("http"),
            method: String::from("http.fetch"),
            params: json!({"url": "https://example.com"}),
            timeout_ms: Some(1000),
            cancel_token: None,
            context: None,
        });

        let backpressure = HostStreamBackpressure {
            credits: Some(8),
            delay_ms: Some(5),
        };
        let chunk = HostStreamChunk {
            index: 1,
            is_last: false,
            backpressure: Some(backpressure),
        };
        expect_result(HostResultPayload {
            call_id: String::from("c7"),
            output: json!({"ok": true}),
            is_error: false,
            error: None,
            chunk: Some(chunk),
        });
    }

    /// Stream chunks serialize their backpressure hints under the expected keys.
    #[test]
    fn host_stream_chunk_serializes_backpressure_fields() {
        let backpressure = HostStreamBackpressure {
            credits: Some(4),
            delay_ms: Some(25),
        };
        let chunk = HostStreamChunk {
            index: 2,
            is_last: false,
            backpressure: Some(backpressure),
        };

        let value = serde_json::to_value(&chunk).expect("serialize host stream chunk");
        let hints = &value["backpressure"];

        assert_eq!(value["index"], json!(2));
        assert_eq!(value["is_last"], json!(false));
        assert_eq!(hints["credits"], json!(4));
        assert_eq!(hints["delay_ms"], json!(25));
    }

    /// Guards against drift between these types and the published protocol schema.
    #[test]
    fn protocol_schema_still_declares_host_stream_backpressure_and_object_output() {
        let raw = include_str!("../../docs/schema/extension_protocol.json");
        let schema: Value = serde_json::from_str(raw).expect("parse extension protocol schema");
        let defs = schema
            .get("$defs")
            .and_then(Value::as_object)
            .expect("schema $defs object");

        let chunk_props = defs
            .get("host_stream_chunk")
            .and_then(|def| def.pointer("/properties"))
            .and_then(Value::as_object)
            .expect("host_stream_chunk properties");
        assert!(
            chunk_props.contains_key("backpressure"),
            "schema drift: host_stream_chunk.backpressure missing",
        );

        let output_type = defs
            .get("host_result")
            .and_then(|def| def.pointer("/properties/output/type"))
            .and_then(Value::as_str)
            .expect("host_result.output.type");
        assert_eq!(
            output_type, "object",
            "schema drift: host_result.output must remain object",
        );
    }
}