1use serde::Serialize;
2use std::collections::HashMap;
3use trace_weft_core::{
4 BlobRef, CapturePolicy, CostEstimate, RunId, SpanId, SpanRecord, SpanStatus, TokenUsage,
5 TraceId, TraceWeftSpanKind,
6};
7use uuid::Uuid;
8
9pub struct SpanBuilder {
10 pub span: SpanRecord,
11 pending_input_ref: Option<PendingCapture>,
12 pending_output_ref: Option<PendingCapture>,
13}
14
15struct PendingCapture {
16 label: String,
17 value: serde_json::Value,
18}
19
20impl SpanBuilder {
21 pub fn new(kind: TraceWeftSpanKind, name: impl Into<String>) -> Self {
22 let now = std::time::SystemTime::now()
23 .duration_since(std::time::UNIX_EPOCH)
24 .unwrap_or_default()
25 .as_millis() as u64;
26
27 Self {
28 span: SpanRecord {
29 trace_id: TraceId(Uuid::now_v7()),
30 span_id: SpanId(Uuid::now_v7()),
31 parent_span_id: None,
32 run_id: RunId(Uuid::now_v7()),
33 session_id: None,
34 user_id_hash: None,
35 project_id: None,
36 span_kind: kind,
37 name: name.into(),
38 start_time: now,
39 end_time: None,
40 status: SpanStatus::InProgress,
41 status_message: None,
42 error_type: None,
43 error_message_redacted: None,
44 attributes: HashMap::new(),
45 otel_attributes: HashMap::new(),
46 openinference_attributes: HashMap::new(),
47 memory_state: None,
48 input_ref: None,
49 output_ref: None,
50 prompt_template_id: None,
51 prompt_version: None,
52 model_provider: None,
53 model_name: None,
54 tool_name: None,
55 tool_schema_hash: None,
56 retrieval_query_hash: None,
57 retrieved_document_refs: vec![],
58 token_usage: None,
59 cost_estimate: None,
60 latency_ms: None,
61 retry_count: None,
62 cache_hit: None,
63 redaction_policy: CapturePolicy::MetadataOnly,
64 schema_version: "1.0".to_string(),
65 },
66 pending_input_ref: None,
67 pending_output_ref: None,
68 }
69 }
70
71 pub fn provider(mut self, provider: impl Into<String>) -> Self {
72 self.span.model_provider = Some(provider.into());
73 self
74 }
75
76 pub fn model(mut self, model: impl Into<String>) -> Self {
77 self.span.model_name = Some(model.into());
78 self
79 }
80
81 pub fn prompt_version(mut self, version: impl Into<String>) -> Self {
82 self.span.prompt_version = Some(version.into());
83 self
84 }
85
86 pub fn tool_name(mut self, tool: impl Into<String>) -> Self {
87 self.span.tool_name = Some(tool.into());
88 self
89 }
90
91 pub fn input_ref<T: Serialize>(mut self, label: impl Into<String>, value: &T) -> Self {
98 self.pending_input_ref = Some(PendingCapture {
99 label: label.into(),
100 value: serde_json::to_value(value).unwrap_or(serde_json::Value::Null),
101 });
102 self
103 }
104
105 pub fn output_ref<T: Serialize>(mut self, label: impl Into<String>, value: &T) -> Self {
111 self.pending_output_ref = Some(PendingCapture {
112 label: label.into(),
113 value: serde_json::to_value(value).unwrap_or(serde_json::Value::Null),
114 });
115 self
116 }
117
118 pub fn input_blob_ref(mut self, blob_ref: BlobRef) -> Self {
120 self.span.input_ref = Some(blob_ref);
121 self
122 }
123
124 pub fn output_blob_ref(mut self, blob_ref: BlobRef) -> Self {
126 self.span.output_ref = Some(blob_ref);
127 self
128 }
129
130 pub fn token_usage(mut self, usage: TokenUsage) -> Self {
131 self.span.token_usage = Some(usage);
132 self
133 }
134
135 pub fn cost(mut self, cost: CostEstimate) -> Self {
136 self.span.cost_estimate = Some(cost);
137 self
138 }
139
140 pub fn cache_hit(mut self, hit: bool) -> Self {
141 self.span.cache_hit = Some(hit);
142 self
143 }
144
145 pub fn retrieval(mut self, query_hash: impl Into<String>, doc_refs: Vec<BlobRef>) -> Self {
147 self.span.retrieval_query_hash = Some(query_hash.into());
148 self.span.retrieved_document_refs = doc_refs;
149 self
150 }
151
152 pub fn attribute(mut self, key: impl Into<String>, value: serde_json::Value) -> Self {
154 self.span.attributes.insert(key.into(), value);
155 self
156 }
157
158 pub fn attributes(mut self, attrs: HashMap<String, serde_json::Value>) -> Self {
160 self.span.attributes.extend(attrs);
161 self
162 }
163
164 pub fn with_parent(mut self, trace_id: TraceId, run_id: RunId, parent_id: SpanId) -> Self {
165 self.span.trace_id = trace_id;
166 self.span.run_id = run_id;
167 self.span.parent_span_id = Some(parent_id);
168 self
169 }
170
171 pub async fn wait_for_approval(mut self) -> Result<crate::hitl::HitlResponse, String> {
172 crate::context::link_to_ambient(&mut self.span);
173 self.span.redaction_policy = crate::capture_policy();
174 self.capture_pending_refs().await;
175 let span_id = self.span.span_id.0.to_string();
176 self.span.status = SpanStatus::PendingApproval;
177
178 let rx = crate::hitl::register_approval(span_id);
179
180 crate::record_span(self.span.clone()).await;
182
183 match rx.await {
185 Ok(response) => {
186 self.span.end_time = Some(
189 std::time::SystemTime::now()
190 .duration_since(std::time::UNIX_EPOCH)
191 .unwrap_or_default()
192 .as_millis() as u64,
193 );
194 self.span.latency_ms = Some(self.span.end_time.unwrap() - self.span.start_time);
195 self.span.status = SpanStatus::Ok;
196 crate::record_span(self.span).await;
197 Ok(response)
198 }
199 Err(_) => Err("Hitl approval channel closed unexpectedly".to_string()),
200 }
201 }
202
203 pub async fn run<F, Fut, T, E>(mut self, f: F) -> Result<T, E>
204 where
205 F: FnOnce() -> Fut,
206 Fut: std::future::Future<Output = Result<T, E>>,
207 E: std::fmt::Debug + std::fmt::Display + 'static,
208 T: serde::de::DeserializeOwned,
209 {
210 self.span.redaction_policy = crate::capture_policy();
211 self.capture_pending_refs().await;
212 let mut span = self.span;
213 crate::context::link_to_ambient(&mut span);
214
215 if let Some(mocked) =
217 crate::replay::get_mocked_output(&span.span_id.0.to_string(), &span.name)
218 {
219 span.end_time = Some(span.start_time);
220 span.latency_ms = Some(0);
221 span.status = SpanStatus::Ok;
222 span.attributes
223 .insert("replayed".to_string(), serde_json::json!(true));
224 crate::record_span(span.clone()).await;
225
226 if let Ok(value) = serde_json::from_value::<T>(mocked) {
227 return Ok(value);
228 }
229 }
230
231 let ctx = crate::context::SpanContext::of(&span);
233 let result = crate::context::scope_current(ctx, f()).await;
234 span.end_time = Some(
235 std::time::SystemTime::now()
236 .duration_since(std::time::UNIX_EPOCH)
237 .unwrap_or_default()
238 .as_millis() as u64,
239 );
240 span.latency_ms = Some(span.end_time.unwrap() - span.start_time);
241
242 match &result {
243 Ok(_) => {
244 span.status = SpanStatus::Ok;
245 }
246 Err(e) => {
247 span.status = SpanStatus::Error;
248 span.error_type = Some(std::any::type_name::<E>().to_string());
249 span.error_message_redacted =
250 Some(crate::redact_text(&e.to_string()).redacted_text);
251 }
252 }
253
254 crate::record_span(span).await;
255
256 result
257 }
258
259 pub async fn run_infallible<F, Fut, T>(mut self, f: F) -> T
263 where
264 F: FnOnce() -> Fut,
265 Fut: std::future::Future<Output = T>,
266 {
267 self.span.redaction_policy = crate::capture_policy();
268 self.capture_pending_refs().await;
269 let mut span = self.span;
270 crate::context::link_to_ambient(&mut span);
271
272 let ctx = crate::context::SpanContext::of(&span);
273 let result = crate::context::scope_current(ctx, f()).await;
274
275 span.end_time = Some(
276 std::time::SystemTime::now()
277 .duration_since(std::time::UNIX_EPOCH)
278 .unwrap_or_default()
279 .as_millis() as u64,
280 );
281 span.latency_ms = Some(span.end_time.unwrap() - span.start_time);
282 span.status = SpanStatus::Ok;
283 crate::record_span(span).await;
284
285 result
286 }
287
288 async fn capture_pending_refs(&mut self) {
289 if self.span.input_ref.is_none()
290 && let Some(pending) = self.pending_input_ref.take()
291 {
292 self.span.input_ref = capture_labeled_json(pending).await;
293 }
294 if self.span.output_ref.is_none()
295 && let Some(pending) = self.pending_output_ref.take()
296 {
297 self.span.output_ref = capture_labeled_json(pending).await;
298 }
299 }
300}
301
302async fn capture_labeled_json(pending: PendingCapture) -> Option<BlobRef> {
303 let mut object = serde_json::Map::new();
304 object.insert(pending.label, pending.value);
305 crate::capture_json("application/json", serde_json::Value::Object(object)).await
306}
307
308pub fn llm_call(name: impl Into<String>) -> SpanBuilder {
309 SpanBuilder::new(TraceWeftSpanKind::LlmCall, name)
310}
311
312pub fn tool(name: impl Into<String>) -> SpanBuilder {
313 SpanBuilder::new(TraceWeftSpanKind::Tool, name)
314}
315
316pub fn agent(name: impl Into<String>) -> SpanBuilder {
317 SpanBuilder::new(TraceWeftSpanKind::Agent, name)
318}