Skip to main content

control_batch_insert_semantic_direct/
control_batch_insert_semantic_direct.rs

1#[path = "support/utils.rs"]
2mod _utils;
3
4use _utils::{boxed_error, cleanup_run, create_client, new_run_id, print_summary, require};
5use mubit_sdk::TransportMode;
6use serde_json::{json, Value};
7use std::error::Error;
8use std::time::Instant;
9
10#[tokio::main(flavor = "current_thread")]
11async fn main() -> Result<(), Box<dyn Error>> {
12    let name = "control_batch_insert_semantic_direct";
13    let started = Instant::now();
14    let client = create_client().await?;
15    let run_id = new_run_id("control_batch_insert_semantic_direct");
16    client.set_run_id(Some(run_id.clone()));
17    client.set_transport(TransportMode::Http);
18
19    let mut passed = true;
20    let mut detail = "validated direct semantic control batch insert".to_string();
21    let mut metrics = json!({});
22
23    let scenario = async {
24        let response = client
25            .control
26            .batch_insert(json!({
27                "run_id": run_id,
28                "deduplicate": true,
29                "items": [
30                    {
31                        "item_id": "release-checkpoint-1",
32                        "text": "Release checkpoint: API gateway patch r2.14 reduced 95th percentile response time from 410ms to 260ms during canary in us-east.",
33                        "metadata_json": "{\"source\":\"release-engineering\",\"region\":\"us-east\",\"team\":\"gateway\"}",
34                        "source": "rust-sdk-batch",
35                        "embedding": [],
36                    },
37                    {
38                        "item_id": "risk-register-2",
39                        "text": "Risk register update: vendor dependency for payment retries has elevated failure probability after certificate chain rotation.",
40                        "metadata_json": "{\"source\":\"risk-office\",\"domain\":\"payments\",\"severity\":\"high\"}",
41                        "source": "rust-sdk-batch",
42                        "embedding": [],
43                    },
44                    {
45                        "item_id": "warehouse-health-3",
46                        "text": "Warehouse health: lane C7 temperature deviation stabilized after recalibration; quality hold can be released after two clean cycles.",
47                        "metadata_json": "{\"source\":\"iot-telemetry\",\"facility\":\"atl-03\",\"lane\":\"C7\"}",
48                        "source": "rust-sdk-batch",
49                        "embedding": [],
50                    },
51                    {
52                        "item_id": "warehouse-health-3-duplicate",
53                        "text": "Warehouse health: lane C7 temperature deviation stabilized after recalibration; quality hold can be released after two clean cycles.",
54                        "metadata_json": "{\"source\":\"iot-telemetry\",\"facility\":\"atl-03\",\"lane\":\"C7\"}",
55                        "source": "rust-sdk-batch",
56                        "embedding": [],
57                    }
58                ]
59            }))
60            .await?;
61
62        let count = response.get("count").and_then(Value::as_u64).unwrap_or(0);
63        let node_ids = response
64            .get("node_ids")
65            .and_then(Value::as_array)
66            .cloned()
67            .unwrap_or_default();
68        let item_results = response
69            .get("item_results")
70            .and_then(Value::as_array)
71            .cloned()
72            .unwrap_or_default();
73
74        require(
75            count >= 3,
76            format!("expected at least 3 inserted items: {response}"),
77        )?;
78        require(
79            node_ids.len() == count as usize,
80            format!("node_ids mismatch: {response}"),
81        )?;
82        require(
83            item_results.len() == 4,
84            format!("item_results mismatch: {response}"),
85        )?;
86
87        let successful_items = item_results
88            .iter()
89            .filter(|item| {
90                item.get("success")
91                    .and_then(Value::as_bool)
92                    .unwrap_or(false)
93            })
94            .count();
95
96        require(
97            successful_items >= 3,
98            format!("expected at least 3 successful item inserts: {response}"),
99        )?;
100
101        let query = client
102            .control
103            .query(json!({
104                "run_id": run_id,
105                "query": "summarize release, risk, and warehouse updates for this run",
106                "schema": "",
107                "mode": "agent_routed",
108                "direct_lane": "semantic_search",
109                "include_linked_runs": false,
110                "limit": 6,
111                "embedding": [],
112            }))
113            .await?;
114
115        require(
116            query.get("mode").and_then(Value::as_str) == Some("agent_routed"),
117            format!("unexpected query mode: {query}"),
118        )?;
119        require(
120            query
121                .get("final_answer")
122                .and_then(Value::as_str)
123                .map(str::trim)
124                .map(|value| !value.is_empty())
125                .unwrap_or(false),
126            format!("final_answer should be non-empty: {query}"),
127        )?;
128        let evidence_count = query
129            .get("evidence")
130            .and_then(Value::as_array)
131            .map(|items| items.len())
132            .unwrap_or(0);
133        require(
134            evidence_count > 0,
135            format!("expected non-empty evidence: {query}"),
136        )?;
137
138        metrics = json!({
139            "run_id": run_id,
140            "count": count,
141            "successful_items": successful_items,
142            "node_ids": node_ids,
143            "evidence_count": evidence_count,
144        });
145
146        Ok::<(), Box<dyn Error>>(())
147    }
148    .await;
149
150    if let Err(err) = scenario {
151        passed = false;
152        detail = err.to_string();
153    }
154
155    let cleanup_ok = cleanup_run(&client, &run_id).await;
156    if !cleanup_ok {
157        passed = false;
158        detail = format!("{detail} | cleanup failures");
159    }
160
161    print_summary(
162        name,
163        passed,
164        &detail,
165        &metrics,
166        started.elapsed().as_secs_f64(),
167        cleanup_ok,
168    );
169
170    if passed {
171        Ok(())
172    } else {
173        Err(boxed_error(detail))
174    }
175}