Skip to main content

nnrp_wasm/
lib.rs

1use nnrp_core::{ProtocolVersion, TransportId};
2use nnrp_transport_provider::{
3    select_transport_with_probe, ProbeSample, RemoteTransportSupport, TransportPolicy,
4    TransportProviderDescriptor, TransportProviderKind,
5};
6use serde::{Deserialize, Serialize};
7use wasm_bindgen::prelude::*;
8
9#[wasm_bindgen]
10pub fn nnrp_wasm_protocol_major() -> u8 {
11    ProtocolVersion::CURRENT.major
12}
13
14#[wasm_bindgen]
15pub fn nnrp_wasm_wire_format() -> u8 {
16    ProtocolVersion::CURRENT.wire_format
17}
18
19#[wasm_bindgen(js_name = selectTransportWithProbeJson)]
20pub fn select_transport_with_probe_json(
21    providers_json: &str,
22    remote_transports_json: &str,
23    policy: &str,
24    samples_json: &str,
25) -> Result<String, JsValue> {
26    let providers = parse_providers(providers_json)?;
27    let remote = RemoteTransportSupport::new(parse_transport_ids(remote_transports_json)?);
28    let policy = parse_policy(policy)?;
29    let samples = parse_probe_samples(samples_json)?;
30    let selection = select_transport_with_probe(&providers, &remote, policy, &samples)
31        .map_err(|error| js_error(&error.to_string()))?;
32
33    serde_json::to_string(&WasmTransportSelection::from(selection))
34        .map_err(|error| js_error(&error.to_string()))
35}
36
37#[wasm_bindgen(js_name = scoreProviderProbeJson)]
38pub fn score_provider_probe_json(
39    provider_json: &str,
40    policy: &str,
41    samples_json: &str,
42) -> Result<String, JsValue> {
43    let provider = parse_provider(provider_json)?;
44    let policy = parse_policy(policy)?;
45    let samples = parse_probe_samples(samples_json)?;
46    let score = nnrp_transport_provider::score_provider_probe(&provider, &samples, policy)
47        .ok_or_else(|| js_error("no matching probe samples for provider"))?;
48
49    serde_json::to_string(&WasmProbeScore::from(score))
50        .map_err(|error| js_error(&error.to_string()))
51}
52
53#[derive(Debug, Deserialize)]
54struct WasmProviderInput {
55    name: String,
56    version: String,
57    transport_id: u32,
58    kind: Option<String>,
59    available: Option<bool>,
60    diagnostic: Option<String>,
61}
62
63#[derive(Debug, Deserialize)]
64struct WasmProbeSampleInput {
65    transport_id: u32,
66    provider_name: String,
67    elapsed_us: u64,
68    rtt_us: Option<u64>,
69    bytes_sent: u64,
70    bytes_received: u64,
71    timed_out: Option<bool>,
72    failed: Option<bool>,
73}
74
75#[derive(Debug, Serialize)]
76struct WasmTransportSelection {
77    selected: WasmProviderOutput,
78    selected_score: WasmProbeScore,
79    candidates: Vec<WasmCandidateScore>,
80    rejected: Vec<WasmRejectedCandidate>,
81}
82
83impl From<nnrp_transport_provider::ProbeSelection> for WasmTransportSelection {
84    fn from(value: nnrp_transport_provider::ProbeSelection) -> Self {
85        Self {
86            selected: value.selected.into(),
87            selected_score: value.selected_score.into(),
88            candidates: value.candidates.into_iter().map(Into::into).collect(),
89            rejected: value.rejected.into_iter().map(Into::into).collect(),
90        }
91    }
92}
93
94#[derive(Debug, Serialize)]
95struct WasmProviderOutput {
96    name: String,
97    version: String,
98    transport_id: u32,
99    kind: String,
100    available: bool,
101    diagnostic: Option<String>,
102}
103
104impl From<TransportProviderDescriptor> for WasmProviderOutput {
105    fn from(value: TransportProviderDescriptor) -> Self {
106        Self {
107            name: value.name,
108            version: value.version,
109            transport_id: value.transport_id as u32,
110            kind: provider_kind_name(value.kind).to_string(),
111            available: value.available,
112            diagnostic: value.diagnostic,
113        }
114    }
115}
116
117#[derive(Debug, Serialize)]
118struct WasmCandidateScore {
119    provider: WasmProviderOutput,
120    probe_score: WasmProbeScore,
121}
122
123impl From<nnrp_transport_provider::ProbeCandidateScore> for WasmCandidateScore {
124    fn from(value: nnrp_transport_provider::ProbeCandidateScore) -> Self {
125        Self {
126            provider: value.provider.into(),
127            probe_score: value.probe_score.into(),
128        }
129    }
130}
131
132#[derive(Debug, Serialize)]
133struct WasmProbeScore {
134    sample_count: usize,
135    failure_count: usize,
136    failure_rate: f64,
137    median_rtt_us: u64,
138    throughput_bytes_per_sec: u64,
139    score: f64,
140}
141
142impl From<nnrp_transport_provider::ProbeScore> for WasmProbeScore {
143    fn from(value: nnrp_transport_provider::ProbeScore) -> Self {
144        Self {
145            sample_count: value.sample_count,
146            failure_count: value.failure_count,
147            failure_rate: value.failure_rate,
148            median_rtt_us: value.median_rtt_us,
149            throughput_bytes_per_sec: value.throughput_bytes_per_sec,
150            score: value.score,
151        }
152    }
153}
154
155#[derive(Debug, Serialize)]
156struct WasmRejectedCandidate {
157    transport_id: u32,
158    provider_name: Option<String>,
159    reason: String,
160}
161
162impl From<nnrp_transport_provider::RejectedTransportCandidate> for WasmRejectedCandidate {
163    fn from(value: nnrp_transport_provider::RejectedTransportCandidate) -> Self {
164        Self {
165            transport_id: value.transport_id as u32,
166            provider_name: value.provider_name,
167            reason: format!("{:?}", value.reason),
168        }
169    }
170}
171
172fn parse_providers(source: &str) -> Result<Vec<TransportProviderDescriptor>, JsValue> {
173    let inputs = serde_json::from_str::<Vec<WasmProviderInput>>(source)
174        .map_err(|error| js_error(&error.to_string()))?;
175    inputs.into_iter().map(provider_from_input).collect()
176}
177
178fn parse_provider(source: &str) -> Result<TransportProviderDescriptor, JsValue> {
179    let input = serde_json::from_str::<WasmProviderInput>(source)
180        .map_err(|error| js_error(&error.to_string()))?;
181    provider_from_input(input)
182}
183
184fn provider_from_input(input: WasmProviderInput) -> Result<TransportProviderDescriptor, JsValue> {
185    let transport_id = parse_transport_id(input.transport_id)?;
186    let kind = parse_provider_kind(input.kind.as_deref().unwrap_or("wasm"))?;
187    if input.available.unwrap_or(true) {
188        Ok(TransportProviderDescriptor::available(
189            input.name,
190            input.version,
191            transport_id,
192            kind,
193        ))
194    } else {
195        Ok(TransportProviderDescriptor::missing(
196            input.name,
197            input.version,
198            transport_id,
199            kind,
200            input
201                .diagnostic
202                .unwrap_or_else(|| "provider is not available".to_string()),
203        ))
204    }
205}
206
207fn parse_probe_samples(source: &str) -> Result<Vec<ProbeSample>, JsValue> {
208    let inputs = serde_json::from_str::<Vec<WasmProbeSampleInput>>(source)
209        .map_err(|error| js_error(&error.to_string()))?;
210    inputs
211        .into_iter()
212        .map(|sample| {
213            Ok(ProbeSample {
214                transport_id: parse_transport_id(sample.transport_id)?,
215                provider_name: sample.provider_name,
216                elapsed_us: sample.elapsed_us,
217                rtt_us: sample.rtt_us,
218                bytes_sent: sample.bytes_sent,
219                bytes_received: sample.bytes_received,
220                timed_out: sample.timed_out.unwrap_or(false),
221                failed: sample.failed.unwrap_or(false),
222            })
223        })
224        .collect()
225}
226
227fn parse_transport_ids(source: &str) -> Result<Vec<TransportId>, JsValue> {
228    let ids =
229        serde_json::from_str::<Vec<u32>>(source).map_err(|error| js_error(&error.to_string()))?;
230    ids.into_iter().map(parse_transport_id).collect()
231}
232
233fn parse_transport_id(value: u32) -> Result<TransportId, JsValue> {
234    TransportId::try_from_u32(value)
235        .map_err(|error| js_error(&format!("invalid transport id: {error}")))
236}
237
238fn parse_policy(value: &str) -> Result<TransportPolicy, JsValue> {
239    match value {
240        "auto" => Ok(TransportPolicy::Auto),
241        "prefer_quic" => Ok(TransportPolicy::PreferQuic),
242        "prefer_tcp" => Ok(TransportPolicy::PreferTcp),
243        "force_quic" => Ok(TransportPolicy::ForceQuic),
244        "force_tcp" => Ok(TransportPolicy::ForceTcp),
245        other => Err(js_error(&format!("unknown transport policy: {other}"))),
246    }
247}
248
249fn parse_provider_kind(value: &str) -> Result<TransportProviderKind, JsValue> {
250    match value {
251        "pure_rust" => Ok(TransportProviderKind::PureRust),
252        "native_dynamic" => Ok(TransportProviderKind::NativeDynamic),
253        "wasm" => Ok(TransportProviderKind::Wasm),
254        other => Err(js_error(&format!("unknown provider kind: {other}"))),
255    }
256}
257
258fn provider_kind_name(kind: TransportProviderKind) -> &'static str {
259    match kind {
260        TransportProviderKind::PureRust => "pure_rust",
261        TransportProviderKind::NativeDynamic => "native_dynamic",
262        TransportProviderKind::Wasm => "wasm",
263    }
264}
265
266fn js_error(message: &str) -> JsValue {
267    #[cfg(target_arch = "wasm32")]
268    {
269        JsValue::from_str(message)
270    }
271    #[cfg(not(target_arch = "wasm32"))]
272    {
273        let _ = message;
274        JsValue::NULL
275    }
276}
277
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::Value;

    #[test]
    fn wasm_protocol_version_exports_current_values() {
        assert_eq!(nnrp_wasm_protocol_major(), ProtocolVersion::CURRENT.major);
        assert_eq!(
            nnrp_wasm_wire_format(),
            ProtocolVersion::CURRENT.wire_format
        );
    }

    #[test]
    fn wasm_probe_selection_prefers_measured_tcp_over_flaky_quic() {
        let providers = r#"[
            {"name":"tcp","version":"0.0.0","transport_id":2,"kind":"wasm","available":true},
            {"name":"quic","version":"0.0.0","transport_id":1,"kind":"wasm","available":true}
        ]"#;
        let samples = r#"[
            {"transport_id":2,"provider_name":"tcp","elapsed_us":20000,"rtt_us":5000,"bytes_sent":1024,"bytes_received":1024},
            {"transport_id":2,"provider_name":"tcp","elapsed_us":20000,"rtt_us":5100,"bytes_sent":1024,"bytes_received":1024},
            {"transport_id":1,"provider_name":"quic","elapsed_us":20000,"rtt_us":800,"bytes_sent":1024,"bytes_received":1024},
            {"transport_id":1,"provider_name":"quic","elapsed_us":20000,"rtt_us":null,"bytes_sent":0,"bytes_received":0,"timed_out":true,"failed":true}
        ]"#;

        let output =
            select_transport_with_probe_json(providers, "[1,2]", "prefer_quic", samples).unwrap();
        let output = serde_json::from_str::<Value>(&output).unwrap();
        assert_eq!(output["selected"]["transport_id"], 2);
        assert_eq!(output["candidates"].as_array().unwrap().len(), 2);
    }

    #[test]
    fn wasm_probe_selection_reports_rejected_unavailable_provider() {
        let providers = r#"[
            {"name":"tcp-native","version":"0.0.0","transport_id":2,"kind":"native_dynamic","available":true},
            {"name":"quic-native","version":"0.0.0","transport_id":1,"kind":"pure_rust","available":false,"diagnostic":"backend missing"}
        ]"#;
        let samples = r#"[
            {"transport_id":2,"provider_name":"tcp-native","elapsed_us":10000,"rtt_us":2500,"bytes_sent":4096,"bytes_received":4096}
        ]"#;

        let output =
            select_transport_with_probe_json(providers, "[1,2]", "prefer_tcp", samples).unwrap();
        let output = serde_json::from_str::<Value>(&output).unwrap();

        assert_eq!(output["selected"]["kind"], "native_dynamic");
        assert_eq!(output["rejected"][0]["transport_id"], 1);
        assert_eq!(output["rejected"][0]["provider_name"], "quic-native");
        assert!(output["rejected"][0]["reason"]
            .as_str()
            .unwrap()
            .contains("LocalProviderUnavailable"));
    }

    #[test]
    fn wasm_score_provider_probe_returns_json_score() {
        let provider = r#"{"name":"quic","version":"0.0.0","transport_id":1,"kind":"pure_rust","available":true}"#;
        let samples = r#"[
            {"transport_id":1,"provider_name":"quic","elapsed_us":8000,"rtt_us":1000,"bytes_sent":2048,"bytes_received":2048},
            {"transport_id":1,"provider_name":"quic","elapsed_us":9000,"rtt_us":1200,"bytes_sent":2048,"bytes_received":2048}
        ]"#;

        let output = score_provider_probe_json(provider, "force_quic", samples).unwrap();
        let output = serde_json::from_str::<Value>(&output).unwrap();

        assert_eq!(output["sample_count"], 2);
        assert_eq!(output["failure_count"], 0);
        assert_eq!(output["median_rtt_us"], 1200);
    }

    #[test]
    fn wasm_score_provider_probe_reports_missing_samples() {
        let provider =
            r#"{"name":"tcp","version":"0.0.0","transport_id":2,"kind":"wasm","available":true}"#;
        assert!(score_provider_probe_json(provider, "auto", "[]").is_err());
    }

    #[test]
    fn wasm_rejects_invalid_policy_kind_and_transport_id() {
        let tcp =
            r#"{"name":"tcp","version":"0.0.0","transport_id":2,"kind":"wasm","available":true}"#;
        let bad_kind =
            r#"{"name":"tcp","version":"0.0.0","transport_id":2,"kind":"plugin","available":true}"#;
        let bad_transport =
            r#"{"name":"tcp","version":"0.0.0","transport_id":99,"kind":"wasm","available":true}"#;
        // A sample matching the `tcp` provider. With an empty sample list every call below
        // would fail with "no matching probe samples" regardless of the invalid input, so
        // the assertions would not prove that validation happens.
        let samples = r#"[
            {"transport_id":2,"provider_name":"tcp","elapsed_us":10000,"rtt_us":2000,"bytes_sent":1024,"bytes_received":1024}
        ]"#;

        // Control: well-formed inputs succeed with these samples, so each failure below is
        // attributable to the single invalid field.
        assert!(score_provider_probe_json(tcp, "force_tcp", samples).is_ok());

        assert!(score_provider_probe_json(tcp, "sticky", samples).is_err());
        assert!(score_provider_probe_json(bad_kind, "force_tcp", samples).is_err());
        assert!(score_provider_probe_json(bad_transport, "force_tcp", samples).is_err());
    }

    #[test]
    fn wasm_provider_kind_names_round_trip() {
        for name in ["pure_rust", "native_dynamic", "wasm"] {
            let kind = parse_provider_kind(name).unwrap();
            assert_eq!(provider_kind_name(kind), name);
        }
    }

    #[test]
    fn wasm_accepts_every_documented_policy_name() {
        for name in ["auto", "prefer_quic", "prefer_tcp", "force_quic", "force_tcp"] {
            assert!(parse_policy(name).is_ok(), "policy {name} should parse");
        }
        assert!(parse_policy("Auto").is_err());
        assert!(parse_policy("").is_err());
    }
}