Skip to main content

aster/telemetry/
tracker.rs

//! Telemetry tracker.
2
3use super::config::*;
4use super::sanitizer::*;
5use super::types::*;
6use parking_lot::RwLock;
7use sha2::{Digest, Sha256};
8use std::collections::HashMap;
9use std::fs::{self, OpenOptions};
10use std::io::{BufRead, BufReader, Write};
11use std::sync::Arc;
12use std::time::{SystemTime, UNIX_EPOCH};
13use tracing::warn;
14
15/// 遥测追踪器
16pub struct TelemetryTracker {
17    config: RwLock<TelemetryConfig>,
18    anonymous_id: String,
19    current_session: RwLock<Option<SessionMetrics>>,
20    event_queue: RwLock<Vec<TelemetryEvent>>,
21}
22
23impl TelemetryTracker {
24    /// 创建新的追踪器
25    pub fn new() -> Self {
26        let config = load_config();
27        let anonymous_id = get_or_create_anonymous_id();
28
29        // 确保目录存在
30        let dir = get_telemetry_dir();
31        if !dir.exists() {
32            let _ = fs::create_dir_all(&dir);
33        }
34
35        Self {
36            config: RwLock::new(config),
37            anonymous_id,
38            current_session: RwLock::new(None),
39            event_queue: RwLock::new(Vec::new()),
40        }
41    }
42
43    /// 检查是否启用
44    pub fn is_enabled(&self) -> bool {
45        self.config.read().enabled
46    }
47
48    /// 开始新会话
49    pub fn start_session(&self, session_id: &str, model: &str) {
50        if !self.is_enabled() {
51            return;
52        }
53
54        let session = SessionMetrics {
55            session_id: session_id.to_string(),
56            start_time: current_timestamp(),
57            model: model.to_string(),
58            ..Default::default()
59        };
60
61        *self.current_session.write() = Some(session);
62        self.track_event(
63            "session_start",
64            HashMap::from([("model".to_string(), serde_json::json!(model))]),
65        );
66    }
67
68    /// 结束会话
69    pub fn end_session(&self) {
70        if !self.is_enabled() {
71            return;
72        }
73
74        let mut session_guard = self.current_session.write();
75        if let Some(ref mut session) = *session_guard {
76            session.end_time = Some(current_timestamp());
77
78            let duration = session.end_time.unwrap() - session.start_time;
79            self.track_event(
80                "session_end",
81                HashMap::from([
82                    ("duration".to_string(), serde_json::json!(duration)),
83                    (
84                        "message_count".to_string(),
85                        serde_json::json!(session.message_count),
86                    ),
87                    (
88                        "token_usage".to_string(),
89                        serde_json::to_value(&session.token_usage).unwrap(),
90                    ),
91                    (
92                        "estimated_cost".to_string(),
93                        serde_json::json!(session.estimated_cost),
94                    ),
95                ]),
96            );
97
98            self.update_aggregate_metrics(session);
99        }
100
101        *session_guard = None;
102    }
103
104    /// 跟踪事件
105    pub fn track_event(&self, event_type: &str, data: HashMap<String, serde_json::Value>) {
106        if !self.is_enabled() {
107            return;
108        }
109
110        let sanitized_data = sanitize_map(&data);
111        let session_id = self
112            .current_session
113            .read()
114            .as_ref()
115            .map(|s| s.session_id.clone())
116            .unwrap_or_else(|| "unknown".to_string());
117
118        let event = TelemetryEvent {
119            event_type: event_type.to_string(),
120            timestamp: current_timestamp(),
121            session_id,
122            anonymous_id: self.anonymous_id.clone(),
123            data: sanitized_data,
124            version: Some(env!("CARGO_PKG_VERSION").to_string()),
125            platform: Some(std::env::consts::OS.to_string()),
126        };
127
128        // 追加到事件文件
129        if let Err(e) = append_to_jsonl(&get_events_file(), &event) {
130            warn!("Failed to write event: {}", e);
131        }
132
133        // 添加到队列
134        let config = self.config.read();
135        if config.batch_upload {
136            let mut queue = self.event_queue.write();
137            queue.push(event);
138            if queue.len() > MAX_QUEUE_SIZE {
139                queue.remove(0);
140            }
141        }
142    }
143
144    /// 跟踪消息
145    pub fn track_message(&self, role: &str) {
146        if !self.is_enabled() {
147            return;
148        }
149
150        if let Some(ref mut session) = *self.current_session.write() {
151            session.message_count += 1;
152        }
153
154        self.track_event(
155            "message",
156            HashMap::from([("role".to_string(), serde_json::json!(role))]),
157        );
158    }
159
160    /// 跟踪工具调用
161    pub fn track_tool_call(&self, tool_name: &str, success: bool, duration: u64) {
162        if !self.is_enabled() {
163            return;
164        }
165
166        if let Some(ref mut session) = *self.current_session.write() {
167            *session.tool_calls.entry(tool_name.to_string()).or_insert(0) += 1;
168            if !success {
169                session.errors += 1;
170            }
171        }
172
173        self.track_event(
174            "tool_call",
175            HashMap::from([
176                ("tool_name".to_string(), serde_json::json!(tool_name)),
177                ("success".to_string(), serde_json::json!(success)),
178                ("duration".to_string(), serde_json::json!(duration)),
179            ]),
180        );
181
182        if self.config.read().performance_tracking {
183            self.track_performance(tool_name, duration, success, None);
184        }
185    }
186
187    /// 跟踪命令使用
188    pub fn track_command(&self, command_name: &str, success: bool, duration: u64) {
189        if !self.is_enabled() {
190            return;
191        }
192
193        self.track_event(
194            "command_use",
195            HashMap::from([
196                ("command_name".to_string(), serde_json::json!(command_name)),
197                ("success".to_string(), serde_json::json!(success)),
198                ("duration".to_string(), serde_json::json!(duration)),
199            ]),
200        );
201
202        if self.config.read().performance_tracking {
203            self.track_performance(
204                &format!("command:{}", command_name),
205                duration,
206                success,
207                None,
208            );
209        }
210    }
211
212    /// 跟踪 token 使用
213    pub fn track_token_usage(&self, input: u64, output: u64, cost: f64) {
214        if !self.is_enabled() {
215            return;
216        }
217
218        if let Some(ref mut session) = *self.current_session.write() {
219            session.token_usage.input += input;
220            session.token_usage.output += output;
221            session.token_usage.total += input + output;
222            session.estimated_cost += cost;
223        }
224
225        self.track_event(
226            "token_usage",
227            HashMap::from([
228                ("input".to_string(), serde_json::json!(input)),
229                ("output".to_string(), serde_json::json!(output)),
230                ("cost".to_string(), serde_json::json!(cost)),
231            ]),
232        );
233    }
234
235    /// 跟踪错误
236    pub fn track_error(&self, error: &str, context: Option<HashMap<String, serde_json::Value>>) {
237        if !self.is_enabled() {
238            return;
239        }
240
241        if let Some(ref mut session) = *self.current_session.write() {
242            session.errors += 1;
243        }
244
245        let mut data = HashMap::from([("error".to_string(), serde_json::json!(error))]);
246        if let Some(ctx) = context {
247            data.extend(ctx);
248        }
249
250        self.track_event("error", data);
251    }
252
253    /// 跟踪性能指标
254    pub fn track_performance(
255        &self,
256        operation: &str,
257        duration: u64,
258        success: bool,
259        metadata: Option<HashMap<String, serde_json::Value>>,
260    ) {
261        if !self.is_enabled() || !self.config.read().performance_tracking {
262            return;
263        }
264
265        let sanitized_metadata = metadata.map(|m| sanitize_map(&m));
266
267        let metric = PerformanceMetric {
268            operation: operation.to_string(),
269            duration,
270            timestamp: current_timestamp(),
271            success,
272            metadata: sanitized_metadata,
273        };
274
275        if let Err(e) = append_to_jsonl(&get_performance_file(), &metric) {
276            warn!("Failed to write performance metric: {}", e);
277        }
278    }
279
280    /// 跟踪详细错误报告
281    pub fn track_error_report(
282        &self,
283        error_type: &str,
284        error_message: &str,
285        stack: Option<String>,
286        context: HashMap<String, serde_json::Value>,
287    ) {
288        if !self.is_enabled() || !self.config.read().error_reporting {
289            return;
290        }
291
292        let sanitized_context = sanitize_map(&context);
293        let session_id = self
294            .current_session
295            .read()
296            .as_ref()
297            .map(|s| s.session_id.clone())
298            .unwrap_or_else(|| "unknown".to_string());
299
300        let report = ErrorReport {
301            error_type: error_type.to_string(),
302            error_message: sanitize_string(error_message),
303            stack: stack.map(|s| sanitize_string(&s)),
304            context: sanitized_context,
305            timestamp: current_timestamp(),
306            session_id,
307            anonymous_id: self.anonymous_id.clone(),
308        };
309
310        if let Err(e) = append_to_jsonl(&get_errors_file(), &report) {
311            warn!("Failed to write error report: {}", e);
312        }
313
314        self.track_error(error_type, None);
315    }
316
317    /// 更新聚合指标
318    fn update_aggregate_metrics(&self, session: &SessionMetrics) {
319        let mut metrics = load_aggregate_metrics().unwrap_or_default();
320
321        metrics.total_sessions += 1;
322        metrics.total_messages += session.message_count;
323        metrics.total_tokens += session.token_usage.total;
324        metrics.total_cost += session.estimated_cost;
325        metrics.total_errors += session.errors;
326
327        for (tool, count) in &session.tool_calls {
328            *metrics.tool_usage.entry(tool.clone()).or_insert(0) += count;
329        }
330
331        *metrics
332            .model_usage
333            .entry(session.model.clone())
334            .or_insert(0) += 1;
335
336        let duration = session.end_time.unwrap_or(current_timestamp()) - session.start_time;
337        metrics.average_session_duration = (metrics.average_session_duration
338            * (metrics.total_sessions - 1) as f64
339            + duration as f64)
340            / metrics.total_sessions as f64;
341
342        metrics.last_updated = current_timestamp();
343
344        if let Err(e) = save_aggregate_metrics(&metrics) {
345            warn!("Failed to save aggregate metrics: {}", e);
346        }
347    }
348
349    /// 获取聚合指标
350    pub fn get_metrics(&self) -> Option<AggregateMetrics> {
351        load_aggregate_metrics()
352    }
353
354    /// 获取当前会话指标
355    pub fn get_current_session(&self) -> Option<SessionMetrics> {
356        self.current_session.read().clone()
357    }
358
359    /// 获取匿名 ID
360    pub fn get_anonymous_id(&self) -> &str {
361        &self.anonymous_id
362    }
363
364    /// 启用遥测
365    pub fn enable(&self) {
366        if is_telemetry_disabled() {
367            warn!("Telemetry disabled via environment variable");
368            return;
369        }
370        self.config.write().enabled = true;
371        self.save_config();
372    }
373
374    /// 禁用遥测
375    pub fn disable(&self) {
376        self.config.write().enabled = false;
377        self.save_config();
378    }
379
380    /// 启用错误报告
381    pub fn enable_error_reporting(&self) {
382        self.config.write().error_reporting = true;
383        self.save_config();
384    }
385
386    /// 禁用错误报告
387    pub fn disable_error_reporting(&self) {
388        self.config.write().error_reporting = false;
389        self.save_config();
390    }
391
392    /// 启用性能追踪
393    pub fn enable_performance_tracking(&self) {
394        self.config.write().performance_tracking = true;
395        self.save_config();
396    }
397
398    /// 禁用性能追踪
399    pub fn disable_performance_tracking(&self) {
400        self.config.write().performance_tracking = false;
401        self.save_config();
402    }
403
404    /// 保存配置
405    fn save_config(&self) {
406        let config = self.config.read().clone();
407        if let Err(e) = save_telemetry_config(&config) {
408            warn!("Failed to save telemetry config: {}", e);
409        }
410    }
411
412    /// 清除所有遥测数据
413    pub fn clear_data(&self) {
414        let files = [
415            get_metrics_file(),
416            get_events_file(),
417            get_errors_file(),
418            get_performance_file(),
419            get_queue_file(),
420        ];
421
422        for file in &files {
423            if file.exists() {
424                let _ = fs::remove_file(file);
425            }
426        }
427    }
428}
429
430impl Default for TelemetryTracker {
431    fn default() -> Self {
432        Self::new()
433    }
434}
435
// Helper functions
437
/// Returns the current wall-clock time as milliseconds since the Unix epoch.
///
/// Yields `0` when the system clock reads earlier than the epoch, and
/// saturates at `u64::MAX` instead of silently truncating the 128-bit
/// millisecond count.
fn current_timestamp() -> u64 {
    let since_epoch = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .unwrap_or_default();
    // Clamp before narrowing so the `as` cast below can never truncate.
    since_epoch.as_millis().min(u128::from(u64::MAX)) as u64
}
445
446/// 获取或创建匿名 ID
447fn get_or_create_anonymous_id() -> String {
448    let id_file = get_anonymous_id_file();
449
450    if id_file.exists() {
451        if let Ok(id) = fs::read_to_string(&id_file) {
452            return id.trim().to_string();
453        }
454    }
455
456    // 生成新的匿名 ID
457    let machine_info = format!(
458        "{}|{}|{}|{}",
459        hostname::get()
460            .map(|h| h.to_string_lossy().to_string())
461            .unwrap_or_default(),
462        std::env::consts::OS,
463        std::env::consts::ARCH,
464        dirs::home_dir()
465            .map(|p| p.to_string_lossy().to_string())
466            .unwrap_or_default(),
467    );
468
469    let mut hasher = Sha256::new();
470    hasher.update(machine_info.as_bytes());
471    let hash = format!("{:x}", hasher.finalize());
472    let id = format!("anon_{}", hash.get(..32).unwrap_or(&hash));
473
474    // 确保目录存在
475    if let Some(parent) = id_file.parent() {
476        let _ = fs::create_dir_all(parent);
477    }
478
479    let _ = fs::write(&id_file, &id);
480    id
481}
482
483/// 加载配置
484fn load_config() -> TelemetryConfig {
485    let config_file = get_config_file();
486    if config_file.exists() {
487        if let Ok(content) = fs::read_to_string(&config_file) {
488            if let Ok(config) = serde_json::from_str(&content) {
489                return config;
490            }
491        }
492    }
493    TelemetryConfig::default()
494}
495
496/// 保存配置
497fn save_telemetry_config(config: &TelemetryConfig) -> Result<(), String> {
498    let config_file = get_config_file();
499    if let Some(parent) = config_file.parent() {
500        fs::create_dir_all(parent).map_err(|e| e.to_string())?;
501    }
502    let content = serde_json::to_string_pretty(config).map_err(|e| e.to_string())?;
503    fs::write(&config_file, content).map_err(|e| e.to_string())
504}
505
506/// 追加到 JSONL 文件
507fn append_to_jsonl<T: serde::Serialize>(path: &std::path::Path, data: &T) -> Result<(), String> {
508    if let Some(parent) = path.parent() {
509        fs::create_dir_all(parent).map_err(|e| e.to_string())?;
510    }
511
512    let mut file = OpenOptions::new()
513        .create(true)
514        .append(true)
515        .open(path)
516        .map_err(|e| e.to_string())?;
517
518    let json = serde_json::to_string(data).map_err(|e| e.to_string())?;
519    writeln!(file, "{}", json).map_err(|e| e.to_string())?;
520
521    // 限制文件大小
522    trim_jsonl_file(path, MAX_EVENTS);
523
524    Ok(())
525}
526
/// Truncates the file at `path` to its last `max_lines` lines.
///
/// Does nothing when the file cannot be opened, cannot be read completely or
/// already holds at most `max_lines` lines. Lines are split on `\n` and copied
/// byte-for-byte, so a line that is not valid UTF-8 neither stops the scan nor
/// hides the newer lines after it. Each kept line is written back followed by
/// `\n`; with `max_lines == 0` the file is emptied. A failure to write the
/// trimmed content is ignored.
fn trim_jsonl_file(path: &std::path::Path, max_lines: usize) {
    let file = match fs::File::open(path) {
        Ok(f) => f,
        Err(_) => return,
    };

    // Never rewrite the file from a partial read: on any read error, bail out
    // and leave the file exactly as it is.
    let lines = match BufReader::new(file)
        .split(b'\n')
        .collect::<std::io::Result<Vec<Vec<u8>>>>()
    {
        Ok(all) => all,
        Err(_) => return,
    };

    if lines.len() <= max_lines {
        return;
    }

    let kept = &lines[lines.len() - max_lines..];
    let mut rewritten: Vec<u8> = Vec::with_capacity(kept.iter().map(|l| l.len() + 1).sum());
    for line in kept {
        rewritten.extend_from_slice(line);
        rewritten.push(b'\n');
    }

    let _ = fs::write(path, rewritten);
}
550
551/// 加载聚合指标
552fn load_aggregate_metrics() -> Option<AggregateMetrics> {
553    let metrics_file = get_metrics_file();
554    if metrics_file.exists() {
555        if let Ok(content) = fs::read_to_string(&metrics_file) {
556            if let Ok(metrics) = serde_json::from_str(&content) {
557                return Some(metrics);
558            }
559        }
560    }
561    None
562}
563
564/// 保存聚合指标
565fn save_aggregate_metrics(metrics: &AggregateMetrics) -> Result<(), String> {
566    let metrics_file = get_metrics_file();
567    if let Some(parent) = metrics_file.parent() {
568        fs::create_dir_all(parent).map_err(|e| e.to_string())?;
569    }
570    let content = serde_json::to_string_pretty(metrics).map_err(|e| e.to_string())?;
571    fs::write(&metrics_file, content).map_err(|e| e.to_string())
572}
573
574/// 全局追踪器
575static GLOBAL_TRACKER: once_cell::sync::Lazy<Arc<TelemetryTracker>> =
576    once_cell::sync::Lazy::new(|| Arc::new(TelemetryTracker::new()));
577
578/// 获取全局追踪器
579pub fn global_tracker() -> Arc<TelemetryTracker> {
580    GLOBAL_TRACKER.clone()
581}