Skip to main content

surreal_client/
client.rs

1use std::sync::Arc;
2
3use ciborium::Value as CborValue;
4use serde_json::{Value, json};
5
6use crate::{Engine, RecordId, RecordRange, Result, SessionState, SurrealError, Table};
7
8pub struct SurrealClient {
9    engine: Arc<tokio::sync::Mutex<Box<dyn Engine>>>,
10    session: SessionState,
11    incremental_id: Arc<std::sync::atomic::AtomicU64>,
12    debug: bool,
13}
14
15impl Clone for SurrealClient {
16    /// Clone the client - creates a new client instance sharing the same engine and session
17    fn clone(&self) -> Self {
18        Self {
19            engine: self.engine.clone(),
20            session: self.session.clone(),
21            incremental_id: self.incremental_id.clone(),
22            debug: self.debug,
23        }
24    }
25}
26
27impl SurrealClient {
28    /// Create a new SurrealDB instance with the given engine and optional namespace/database
29    pub fn new(
30        engine: Box<dyn Engine>,
31        namespace: Option<String>,
32        database: Option<String>,
33    ) -> Self {
34        let mut session = SessionState::new();
35        session.set_target(namespace, database);
36
37        Self {
38            engine: Arc::new(tokio::sync::Mutex::new(engine)),
39            session,
40            incremental_id: Arc::new(std::sync::atomic::AtomicU64::new(0)),
41            debug: false,
42        }
43    }
44
45    /// Enable debug mode to log queries
46    pub fn with_debug(mut self, enabled: bool) -> Self {
47        self.debug = enabled;
48        self
49    }
50
51    /// Check if debug mode is enabled
52    pub fn is_debug(&self) -> bool {
53        self.debug
54    }
55
56    /// Set a parameter for the session
57    pub async fn let_var(&mut self, key: &str, value: Value) -> Result<()> {
58        let mut engine = self.engine.lock().await;
59
60        let params = json!([key, value]);
61
62        engine.send_message("let", params).await?;
63
64        // Store the variable in the session
65        self.session.set_param(key.to_string(), value);
66
67        Ok(())
68    }
69
70    /// Unset a parameter from the session
71    pub async fn unset(&mut self, key: &str) -> Result<()> {
72        let mut engine = self.engine.lock().await;
73
74        let params = json!([key]);
75
76        engine.send_message("unset", params).await?;
77
78        // Remove the variable from the session
79        self.session.unset_param(key);
80        Ok(())
81    }
82
83    /// Create a record in the database
84    pub async fn create(&self, resource: &str, data: Option<Value>) -> Result<Value> {
85        let mut engine = self.engine.lock().await;
86
87        let params = if let Some(data) = data {
88            json!([resource, data])
89        } else {
90            json!([resource])
91        };
92
93        engine.send_message("create", params).await
94    }
95
96    /// Select records from the database
97    pub async fn select(&self, resource: &str) -> Result<Value> {
98        let mut engine = self.engine.lock().await;
99
100        let params = json!([resource]);
101
102        engine.send_message("select", params).await
103    }
104
105    /// Select all records from a table
106    pub async fn select_all(&self, table: Table) -> Result<Value> {
107        self.select(&table.to_string()).await
108    }
109
110    /// Select a specific record by ID
111    pub async fn select_record(&self, record: RecordId) -> Result<Value> {
112        self.select(&record.to_string()).await
113    }
114
115    /// Select a range of records
116    pub async fn select_range(&self, range: RecordRange) -> Result<Value> {
117        self.select(&range.to_string()).await
118    }
119
120    /// Update records in the database
121    pub async fn update(&self, resource: &str, data: Option<Value>) -> Result<Value> {
122        let mut engine = self.engine.lock().await;
123
124        let params = if let Some(data) = data {
125            json!([resource, data])
126        } else {
127            json!([resource])
128        };
129
130        engine.send_message("update", params).await
131    }
132
133    /// Update a specific record by ID
134    pub async fn update_record(&self, record: RecordId, data: Value) -> Result<Value> {
135        self.update(&record.to_string(), Some(data)).await
136    }
137
138    /// Update all records in a table
139    pub async fn update_all(&self, table: Table, data: Value) -> Result<Value> {
140        self.update(&table.to_string(), Some(data)).await
141    }
142
143    /// Upsert (insert or update) records in the database
144    pub async fn upsert(&self, resource: &str, data: Option<Value>) -> Result<Value> {
145        let mut engine = self.engine.lock().await;
146
147        let params = if let Some(data) = data {
148            json!([resource, data])
149        } else {
150            json!([resource])
151        };
152
153        engine.send_message("upsert", params).await
154    }
155
156    /// Upsert a specific record by ID
157    pub async fn upsert_record(&self, record: RecordId, data: Value) -> Result<Value> {
158        self.upsert(&record.to_string(), Some(data)).await
159    }
160
161    /// Merge data into records in the database
162    pub async fn merge(&self, resource: &str, data: Value) -> Result<Value> {
163        let mut engine = self.engine.lock().await;
164
165        let params = json!([resource, data]);
166
167        engine.send_message("merge", params).await
168    }
169
170    /// Merge data into a specific record by ID
171    pub async fn merge_record(&self, record: RecordId, data: Value) -> Result<Value> {
172        self.merge(&record.to_string(), data).await
173    }
174
175    /// Merge data into all records in a table
176    pub async fn merge_all(&self, table: Table, data: Value) -> Result<Value> {
177        self.merge(&table.to_string(), data).await
178    }
179
180    /// Apply JSON patches to records
181    /// Apply patches to records in the database
182    pub async fn patch(&self, resource: &str, patches: Vec<Value>) -> Result<Value> {
183        let mut engine = self.engine.lock().await;
184
185        let params = json!([resource, patches]);
186
187        engine.send_message("patch", params).await
188    }
189
190    /// Delete records from the database
191    pub async fn delete(&self, resource: &str) -> Result<Value> {
192        let mut engine = self.engine.lock().await;
193
194        let params = json!([resource]);
195
196        engine.send_message("delete", params).await
197    }
198
199    /// Delete a specific record by ID
200    pub async fn delete_record(&self, record: RecordId) -> Result<Value> {
201        self.delete(&record.to_string()).await
202    }
203
204    /// Delete all records from a table
205    pub async fn delete_all(&self, table: Table) -> Result<Value> {
206        self.delete(&table.to_string()).await
207    }
208
209    /// Insert records into the database
210    /// Insert data into a table
211    pub async fn insert(&self, table: &str, data: Value) -> Result<Value> {
212        let mut engine = self.engine.lock().await;
213
214        let params = json!([table, data]);
215
216        engine.send_message("insert", params).await
217    }
218
219    /// Insert multiple records
220    pub async fn insert_many(&self, table: Table, data: Vec<Value>) -> Result<Value> {
221        // TODO: add single test
222        self.insert(&table.to_string(), Value::Array(data)).await
223    }
224
225    /// Create a relation between records
226    pub async fn relate(
227        &self,
228        from: &str,
229        relation: &str,
230        to: &str,
231        data: Option<Value>,
232    ) -> Result<Value> {
233        let mut engine = self.engine.lock().await;
234
235        let params = if let Some(data) = data {
236            json!([from, relation, to, data])
237        } else {
238            json!([from, relation, to])
239        };
240
241        engine.send_message("relate", params).await
242    }
243
244    /// Create a relation between specific records
245    pub async fn relate_records(
246        &self,
247        from: RecordId,
248        relation: Table,
249        to: RecordId,
250        data: Option<Value>,
251    ) -> Result<Value> {
252        self.relate(
253            &from.to_string(),
254            &relation.to_string(),
255            &to.to_string(),
256            data,
257        )
258        .await
259    }
260
261    /// Run a stored function
262    pub async fn run(&self, func: &str, args: Option<Value>) -> Result<Value> {
263        let mut engine = self.engine.lock().await;
264
265        let params = if let Some(args) = args {
266            json!([func, args])
267        } else {
268            json!([func])
269        };
270
271        engine.send_message("run", params).await
272    }
273
274    /// Execute a custom SurrealQL query
275    pub async fn query(&self, sql: &str, variables: Option<Value>) -> Result<Value> {
276        if self.debug {
277            if let Some(ref vars) = variables {
278                println!("🔍 SQL: {}", sql);
279                println!(
280                    "📊 Params: {}",
281                    serde_json::to_string_pretty(vars).unwrap_or_default()
282                );
283            } else {
284                println!("🔍 SQL: {}", sql);
285            }
286        }
287
288        let mut engine = self.engine.lock().await;
289
290        let params = if let Some(vars) = variables {
291            json!([sql, vars])
292        } else {
293            json!([sql])
294        };
295
296        let response = engine.send_message("query", params).await?;
297
298        if self.debug {
299            // Check if response contains status field to determine icon
300            let icon = if let Value::Array(ref results) = response {
301                if results
302                    .iter()
303                    .any(|r| r.get("status").and_then(|s| s.as_str()) == Some("ERR"))
304                {
305                    "❌"
306                } else {
307                    "✅"
308                }
309            } else {
310                "✅"
311            };
312
313            println!(
314                "{} Response: {}",
315                icon,
316                serde_json::to_string_pretty(&response).unwrap_or_default()
317            );
318        }
319
320        // Handle the query response format
321        match response {
322            Value::Array(results) => {
323                // Return the results array directly
324                Ok(Value::Array(results))
325            }
326            other => Ok(other),
327        }
328    }
329
330    /// Get information about the current session
331    pub async fn info(&self) -> Result<Value> {
332        let mut engine = self.engine.lock().await;
333
334        let params = json!([]);
335
336        engine.send_message("info", params).await
337    }
338
339    /// Get the version of the SurrealDB instance
340    pub async fn version(&self) -> Result<String> {
341        let mut engine = self.engine.lock().await;
342
343        let params = json!([]);
344
345        let response = engine.send_message("version", params).await?;
346
347        match response {
348            Value::String(version) => Ok(version),
349            _ => Err(SurrealError::Protocol(
350                "Invalid version response format".to_string(),
351            )),
352        }
353    }
354
355    /// Close the connection
356    pub async fn close(self) -> Result<()> {
357        // Note: engine is moved here since we're taking ownership
358        // The session will be dropped automatically
359        // Engine trait doesn't have close method in minimal implementation
360        Ok(())
361    }
362
363    /// Import database content (HTTP only)
364    pub async fn import(&self, _content: &str, _username: &str, _password: &str) -> Result<Value> {
365        Err(SurrealError::Protocol(
366            "Import is not supported in minimal engine implementation".to_string(),
367        ))
368    }
369
370    /// Export database content (HTTP only)
371    pub async fn export(&self, _username: &str, _password: &str) -> Result<String> {
372        Err(SurrealError::Protocol(
373            "Export is not supported in minimal engine implementation".to_string(),
374        ))
375    }
376
377    /// Import ML model (HTTP only)
378    pub async fn import_ml(
379        &self,
380        _content: &str,
381        _username: Option<&str>,
382        _password: Option<&str>,
383    ) -> Result<Value> {
384        Err(SurrealError::Protocol(
385            "ML import is not supported in minimal engine implementation".to_string(),
386        ))
387    }
388
389    /// Export ML model (HTTP only)
390    pub async fn export_ml(
391        &self,
392        _name: &str,
393        _version: Option<&str>,
394        _username: Option<&str>,
395        _password: Option<&str>,
396    ) -> Result<String> {
397        Err(SurrealError::Protocol(
398            "ML export is not supported in minimal engine implementation".to_string(),
399        ))
400    }
401
402    /// Execute a custom SurrealQL query with CBOR parameters
403    pub async fn query_cbor(&self, sql: &str, variables: CborValue) -> Result<CborValue> {
404        let mut engine = self.engine.lock().await;
405
406        if self.debug {
407            println!("SQL: {}", sql);
408            println!("Params: {:?}", variables);
409        }
410
411        let params = CborValue::Array(vec![CborValue::Text(sql.to_string()), variables]);
412        let response = engine.send_message_cbor("query", params).await?;
413
414        if self.debug {
415            println!("✅ CBOR Response: {:?}", response);
416        }
417
418        Ok(response)
419    }
420}
421
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;
    use std::sync::Mutex as StdMutex;

    /// The canned reply [`MockEngine`] returns for every request.
    const MOCK_RESPONSE: &str = "mock_response";

    /// One `(method, params)` pair captured by [`MockEngine`].
    type Call = (String, Value);

    /// Mock engine that replies `"mock_response"` to every request and records each JSON
    /// call, so tests can check the exact RPC method and params the client sent.
    #[derive(Clone, Default)]
    struct MockEngine {
        calls: Arc<StdMutex<Vec<Call>>>,
    }

    impl MockEngine {
        /// Remove and return every JSON call recorded so far.
        fn take_calls(&self) -> Vec<Call> {
            std::mem::take(&mut *self.calls.lock().expect("mock call log poisoned"))
        }
    }

    #[async_trait::async_trait]
    impl Engine for MockEngine {
        async fn send_message(&mut self, method: &str, params: Value) -> Result<Value> {
            self.calls
                .lock()
                .expect("mock call log poisoned")
                .push((method.to_string(), params));
            Ok(mock_response())
        }

        async fn send_message_cbor(
            &mut self,
            _method: &str,
            _params: CborValue,
        ) -> Result<CborValue> {
            Ok(CborValue::Text(MOCK_RESPONSE.to_string()))
        }
    }

    /// The JSON value every [`MockEngine::send_message`] call returns.
    fn mock_response() -> Value {
        Value::String(MOCK_RESPONSE.to_string())
    }

    /// Shorthand for an expected [`Call`].
    fn call(method: &str, params: Value) -> Call {
        (method.to_string(), params)
    }

    /// Build a client over a fresh [`MockEngine`], returning a handle to the mock's call log.
    fn mock_client(
        namespace: Option<&str>,
        database: Option<&str>,
    ) -> (SurrealClient, MockEngine) {
        let mock = MockEngine::default();
        let client = SurrealClient::new(
            Box::new(mock.clone()),
            namespace.map(str::to_string),
            database.map(str::to_string),
        );
        (client, mock)
    }

    #[tokio::test]
    async fn test_surrealdb_creation() {
        let (client, mock) = mock_client(None, None);
        assert!(!client.is_debug());
        assert!(client.with_debug(true).is_debug());
        // Construction only touches local state: nothing is sent to the engine.
        assert!(mock.take_calls().is_empty());
    }

    #[tokio::test]
    async fn test_connect_and_operations() {
        let (client, mock) = mock_client(Some("test_ns"), Some("test_db"));

        let result = client.select("user").await.unwrap();
        assert_eq!(result, mock_response());

        let result = client
            .create("user", Some(json!({"name": "John"})))
            .await
            .unwrap();
        assert_eq!(result, mock_response());

        assert_eq!(
            mock.take_calls(),
            vec![
                call("select", json!(["user"])),
                call("create", json!(["user", {"name": "John"}])),
            ]
        );
    }

    #[tokio::test]
    async fn test_crud_operations() {
        let (client, mock) = mock_client(Some("test_ns"), Some("test_db"));

        let create_result = client
            .create("users", Some(json!({"name": "Alice", "age": 30})))
            .await
            .unwrap();
        assert_eq!(create_result, mock_response());

        let read_result = client.select("users").await.unwrap();
        assert_eq!(read_result, mock_response());

        let update_result = client
            .update("users:alice", Some(json!({"age": 31})))
            .await
            .unwrap();
        assert_eq!(update_result, mock_response());

        let delete_result = client.delete("users:alice").await.unwrap();
        assert_eq!(delete_result, mock_response());

        let insert_result = client
            .insert("users", json!({"name": "Bob", "age": 25}))
            .await
            .unwrap();
        assert_eq!(insert_result, mock_response());

        let merge_result = client
            .merge("users:bob", json!({"city": "New York"}))
            .await
            .unwrap();
        assert_eq!(merge_result, mock_response());

        let upsert_result = client
            .upsert("users:charlie", Some(json!({"name": "Charlie", "age": 28})))
            .await
            .unwrap();
        assert_eq!(upsert_result, mock_response());

        assert_eq!(
            mock.take_calls(),
            vec![
                call("create", json!(["users", {"name": "Alice", "age": 30}])),
                call("select", json!(["users"])),
                call("update", json!(["users:alice", {"age": 31}])),
                call("delete", json!(["users:alice"])),
                call("insert", json!(["users", {"name": "Bob", "age": 25}])),
                call("merge", json!(["users:bob", {"city": "New York"}])),
                call("upsert", json!(["users:charlie", {"name": "Charlie", "age": 28}])),
            ]
        );
    }

    #[tokio::test]
    async fn test_optional_arguments_are_omitted_from_params() {
        let (client, mock) = mock_client(None, None);

        client.create("user", None).await.unwrap();
        client.update("user:john", None).await.unwrap();
        client.upsert("user:john", None).await.unwrap();
        client.relate("user:a", "knows", "user:b", None).await.unwrap();
        client
            .relate("user:a", "knows", "user:b", Some(json!({"since": 2020})))
            .await
            .unwrap();
        client.run("fn::greet", None).await.unwrap();
        client.run("fn::greet", Some(json!(["Ann"]))).await.unwrap();
        client.query("RETURN 1", None).await.unwrap();
        client.query("RETURN $x", Some(json!({"x": 1}))).await.unwrap();
        client
            .patch("user:john", vec![json!({"op": "remove", "path": "/age"})])
            .await
            .unwrap();
        client.info().await.unwrap();

        assert_eq!(
            mock.take_calls(),
            vec![
                call("create", json!(["user"])),
                call("update", json!(["user:john"])),
                call("upsert", json!(["user:john"])),
                call("relate", json!(["user:a", "knows", "user:b"])),
                call("relate", json!(["user:a", "knows", "user:b", {"since": 2020}])),
                call("run", json!(["fn::greet"])),
                call("run", json!(["fn::greet", ["Ann"]])),
                call("query", json!(["RETURN 1"])),
                call("query", json!(["RETURN $x", {"x": 1}])),
                call("patch", json!(["user:john", [{"op": "remove", "path": "/age"}]])),
                call("info", json!([])),
            ]
        );
    }

    #[tokio::test]
    async fn test_version_and_cbor_query() {
        let (client, mock) = mock_client(None, None);

        assert_eq!(client.version().await.unwrap(), MOCK_RESPONSE);
        assert_eq!(mock.take_calls(), vec![call("version", json!([]))]);

        let response = client
            .query_cbor("RETURN 1", CborValue::Null)
            .await
            .unwrap();
        assert_eq!(response, CborValue::Text(MOCK_RESPONSE.to_string()));
    }

    #[tokio::test]
    async fn test_bakery_queries_with_parameters() {
        let (mut client, mock) = mock_client(Some("bakery"), Some("inventory"));

        // Set variables
        client.let_var("min_stock", json!(10)).await.unwrap();
        client.let_var("category", json!("bread")).await.unwrap();
        assert_eq!(
            mock.take_calls(),
            vec![
                call("let", json!(["min_stock", 10])),
                call("let", json!(["category", "bread"])),
            ]
        );

        // Parameterized query relying on session variables
        let query =
            "SELECT * FROM products WHERE stock_level < $min_stock AND category = $category";
        let result = client.query(query, None).await.unwrap();
        assert_eq!(result, mock_response());

        // Query with inline parameters
        let variables = json!({
            "supplier": "FreshBake Co",
            "min_price": 5.0
        });

        let query_with_params = "SELECT * FROM products WHERE supplier = $supplier AND price >= $min_price ORDER BY price DESC";
        let result = client
            .query(query_with_params, Some(variables.clone()))
            .await
            .unwrap();
        assert_eq!(result, mock_response());
        assert_eq!(
            mock.take_calls(),
            vec![
                call("query", json!([query])),
                call("query", json!([query_with_params, variables])),
            ]
        );

        // Complex aggregation query
        let analytics_query = r#"
            SELECT
                category,
                COUNT() as total_products,
                SUM(stock_level) as total_stock,
                AVG(price) as avg_price,
                MAX(price) as max_price,
                MIN(price) as min_price
            FROM products
            WHERE stock_level > 0
            GROUP BY category
            ORDER BY total_stock DESC
        "#;

        let result = client.query(analytics_query, None).await.unwrap();
        assert_eq!(result, mock_response());

        // Relation query
        let relation_query = r#"
            SELECT *,
                ->supplied_by->suppliers.* as supplier_info
            FROM products
            WHERE category = 'pastries'
        "#;

        let result = client.query(relation_query, None).await.unwrap();
        assert_eq!(result, mock_response());

        // Time-based query
        let time_query = r#"
            SELECT *
            FROM orders
            WHERE created_at >= time::now() - 7d
            ORDER BY created_at DESC
            LIMIT 50
        "#;

        let result = client.query(time_query, None).await.unwrap();
        assert_eq!(result, mock_response());

        // Discard the query calls so only the clean-up is checked below.
        let _ = mock.take_calls();

        // Clean up variables
        client.unset("min_stock").await.unwrap();
        client.unset("category").await.unwrap();
        assert_eq!(
            mock.take_calls(),
            vec![
                call("unset", json!(["min_stock"])),
                call("unset", json!(["category"])),
            ]
        );
    }

    #[tokio::test]
    async fn test_complex_analytics_queries() {
        // The mock does not parse SQL: these fixtures only exercise verbatim pass-through.
        // NOTE(review): several use constructs that are likely not valid SurrealQL
        // (e.g. `LEFT JOIN`, table aliases) — do not treat them as usage examples.
        let (client, mock) = mock_client(Some("analytics"), Some("business"));

        // Revenue analysis
        let revenue_query = r#"
            SELECT
                date::format(created_at, '%Y-%m') as month,
                SUM(total_amount) as monthly_revenue,
                COUNT() as order_count,
                AVG(total_amount) as avg_order_value
            FROM orders
            WHERE created_at >= time::now() - 12mo
            GROUP BY month
            ORDER BY month DESC
        "#;

        let result = client.query(revenue_query, None).await.unwrap();
        assert_eq!(result, mock_response());

        // Customer segmentation
        let segmentation_query = r#"
            SELECT
                CASE
                    WHEN total_spent >= 1000 THEN 'Premium'
                    WHEN total_spent >= 500 THEN 'Regular'
                    ELSE 'Basic'
                END as segment,
                COUNT() as customer_count,
                AVG(total_spent) as avg_spent,
                SUM(total_spent) as segment_revenue
            FROM (
                SELECT
                    customer_id,
                    SUM(total_amount) as total_spent
                FROM orders
                GROUP BY customer_id
            ) as customer_totals
            GROUP BY segment
            ORDER BY avg_spent DESC
        "#;

        let result = client.query(segmentation_query, None).await.unwrap();
        assert_eq!(result, mock_response());

        // Product performance with inventory correlation
        let performance_query = r#"
            SELECT
                p.id,
                p.name,
                p.category,
                COUNT(oi.id) as times_ordered,
                SUM(oi.quantity) as total_quantity_sold,
                SUM(oi.price * oi.quantity) as total_revenue,
                p.stock_level as current_stock,
                CASE
                    WHEN p.stock_level = 0 THEN 'Out of Stock'
                    WHEN p.stock_level < 10 THEN 'Low Stock'
                    WHEN p.stock_level < 50 THEN 'Medium Stock'
                    ELSE 'High Stock'
                END as stock_status
            FROM products p
            LEFT JOIN order_items oi ON oi.product_id = p.id
            GROUP BY p.id, p.name, p.category, p.stock_level
            ORDER BY total_revenue DESC
        "#;

        let result = client.query(performance_query, None).await.unwrap();
        assert_eq!(result, mock_response());

        // Every query must reach the engine unchanged, as the sole param of a `query` RPC.
        assert_eq!(
            mock.take_calls(),
            vec![
                call("query", json!([revenue_query])),
                call("query", json!([segmentation_query])),
                call("query", json!([performance_query])),
            ]
        );
    }
}