Skip to main content

supabase_client_query/
upsert.rs

1use std::marker::PhantomData;
2
3use serde::de::DeserializeOwned;
4
5use supabase_client_core::SupabaseResponse;
6
7use crate::backend::QueryBackend;
8use crate::modifier::Modifiable;
9use crate::sql::{ParamStore, SqlParts};
10
11/// Builder for UPSERT (INSERT ... ON CONFLICT DO UPDATE) queries.
12/// Implements Modifiable. Call `.select()` for RETURNING clause.
13pub struct UpsertBuilder<T> {
14    pub(crate) backend: QueryBackend,
15    pub(crate) parts: SqlParts,
16    pub(crate) params: ParamStore,
17    pub(crate) _marker: PhantomData<T>,
18}
19
20impl<T> Modifiable for UpsertBuilder<T> {
21    fn parts_mut(&mut self) -> &mut SqlParts {
22        &mut self.parts
23    }
24}
25
26impl<T> UpsertBuilder<T> {
27    /// Set the conflict columns for ON CONFLICT.
28    pub fn on_conflict(mut self, columns: &[&str]) -> Self {
29        self.parts.conflict_columns = columns.iter().map(|c| c.to_string()).collect();
30        self
31    }
32
33    /// Set a constraint name for ON CONFLICT ON CONSTRAINT.
34    pub fn on_conflict_constraint(mut self, constraint: &str) -> Self {
35        self.parts.conflict_constraint = Some(constraint.to_string());
36        self
37    }
38
39    /// Use ON CONFLICT DO NOTHING instead of DO UPDATE.
40    ///
41    /// When set, duplicate rows are silently ignored instead of updated.
42    pub fn ignore_duplicates(mut self) -> Self {
43        self.parts.ignore_duplicates = true;
44        self
45    }
46
47    /// Override the schema for this query.
48    ///
49    /// Generates `"schema"."table"` instead of the default schema.
50    pub fn schema(mut self, schema: &str) -> Self {
51        self.parts.schema_override = Some(schema.to_string());
52        self
53    }
54
55    /// Add RETURNING * clause.
56    pub fn select(mut self) -> Self {
57        self.parts.returning = Some("*".to_string());
58        self
59    }
60
61    /// Add RETURNING with specific columns.
62    pub fn select_columns(mut self, columns: &str) -> Self {
63        if columns == "*" || columns.is_empty() {
64            self.parts.returning = Some("*".to_string());
65        } else {
66            let quoted = columns
67                .split(',')
68                .map(|c| {
69                    let c = c.trim();
70                    if c.contains('(') || c.contains('*') || c.contains('"') {
71                        c.to_string()
72                    } else {
73                        format!("\"{}\"", c)
74                    }
75                })
76                .collect::<Vec<_>>()
77                .join(", ");
78            self.parts.returning = Some(quoted);
79        }
80        self
81    }
82}
83
#[cfg(test)]
mod tests {
    use super::*;
    use crate::backend::QueryBackend;
    use crate::sql::{ParamStore, SqlOperation, SqlParam, SqlParts};
    use serde_json::Value as JsonValue;
    use std::marker::PhantomData;
    use std::sync::Arc;
    use wiremock::matchers::{method, path};
    use wiremock::{Mock, MockServer, ResponseTemplate};

    // ---- Shared fixtures ----

    /// Upsert parts and parameters for `public.users` carrying the row
    /// `id = 1`, `name = "Alice"`.
    fn sample_row() -> (SqlParts, ParamStore) {
        let mut parts = SqlParts::new(SqlOperation::Upsert, "public", "users");
        let mut params = ParamStore::new();
        let id_idx = params.push(SqlParam::I32(1));
        parts.set_clauses.push(("id".to_string(), id_idx));
        let name_idx = params.push(SqlParam::Text("Alice".to_string()));
        parts.set_clauses.push(("name".to_string(), name_idx));
        (parts, params)
    }

    /// Wraps the given parts/params in a REST-backed builder pointed at `base_url`.
    fn rest_builder(
        base_url: &str,
        parts: SqlParts,
        params: ParamStore,
    ) -> UpsertBuilder<JsonValue> {
        UpsertBuilder {
            backend: QueryBackend::Rest {
                http: reqwest::Client::new(),
                base_url: Arc::from(base_url),
                api_key: Arc::from("test-key"),
                schema: "public".to_string(),
            },
            parts,
            params,
            _marker: PhantomData,
        }
    }

    /// Builder used by the pure builder-method tests; it never hits the network.
    fn make_upsert_builder() -> UpsertBuilder<JsonValue> {
        let (parts, params) = sample_row();
        rest_builder("http://localhost", parts, params)
    }

    /// Starts a mock server answering `POST /rest/v1/users` with `response`.
    async fn mock_users_endpoint(response: ResponseTemplate) -> MockServer {
        let server = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path("/rest/v1/users"))
            .respond_with(response)
            .mount(&server)
            .await;
        server
    }

    // ---- Builder method tests ----

    #[test]
    fn test_on_conflict_sets_columns() {
        let builder = make_upsert_builder().on_conflict(&["id"]);
        assert_eq!(builder.parts.conflict_columns, vec!["id".to_string()]);
    }

    #[test]
    fn test_on_conflict_multiple_columns() {
        let builder = make_upsert_builder().on_conflict(&["id", "email"]);
        assert_eq!(
            builder.parts.conflict_columns,
            vec!["id".to_string(), "email".to_string()]
        );
    }

    #[test]
    fn test_on_conflict_constraint_sets_name() {
        let builder = make_upsert_builder().on_conflict_constraint("users_pkey");
        assert_eq!(builder.parts.conflict_constraint.as_deref(), Some("users_pkey"));
    }

    #[test]
    fn test_ignore_duplicates_sets_flag() {
        let builder = make_upsert_builder().ignore_duplicates();
        assert!(builder.parts.ignore_duplicates);
    }

    #[test]
    fn test_schema_sets_override() {
        let builder = make_upsert_builder().schema("custom");
        assert_eq!(builder.parts.schema_override.as_deref(), Some("custom"));
    }

    #[test]
    fn test_select_sets_returning_star() {
        let builder = make_upsert_builder().select();
        assert_eq!(builder.parts.returning.as_deref(), Some("*"));
    }

    #[test]
    fn test_select_columns_star() {
        let builder = make_upsert_builder().select_columns("*");
        assert_eq!(builder.parts.returning.as_deref(), Some("*"));
    }

    #[test]
    fn test_select_columns_empty() {
        let builder = make_upsert_builder().select_columns("");
        assert_eq!(builder.parts.returning.as_deref(), Some("*"));
    }

    #[test]
    fn test_select_columns_specific() {
        let builder = make_upsert_builder().select_columns("id, name");
        assert_eq!(builder.parts.returning.as_deref(), Some("\"id\", \"name\""));
    }

    #[test]
    fn test_select_columns_complex() {
        let builder = make_upsert_builder().select_columns("count(*)");
        assert_eq!(builder.parts.returning.as_deref(), Some("count(*)"));
    }

    // ---- execute() via wiremock ----

    #[tokio::test]
    async fn test_execute_upsert_success() {
        let mock_server = mock_users_endpoint(
            ResponseTemplate::new(201)
                .set_body_json(serde_json::json!([{"id": 1, "name": "Alice"}])),
        )
        .await;

        let (mut parts, params) = sample_row();
        parts.conflict_columns = vec!["id".to_string()];
        parts.returning = Some("*".to_string());

        let resp = rest_builder(&mock_server.uri(), parts, params)
            .execute()
            .await;
        assert!(resp.is_ok());
        assert_eq!(resp.data.len(), 1);
        assert_eq!(resp.data[0]["name"], "Alice");
        assert_eq!(resp.status, supabase_client_core::StatusCode::Created);
    }

    #[tokio::test]
    async fn test_execute_upsert_error() {
        let mock_server = mock_users_endpoint(
            ResponseTemplate::new(400).set_body_json(serde_json::json!({
                "message": "Constraint violation",
                "code": "23514"
            })),
        )
        .await;

        let (mut parts, params) = sample_row();
        parts.conflict_columns = vec!["id".to_string()];

        let resp = rest_builder(&mock_server.uri(), parts, params)
            .execute()
            .await;
        assert!(resp.is_err());
        match resp.error.as_ref().unwrap() {
            supabase_client_core::SupabaseError::PostgRest { status, message, code } => {
                assert_eq!(*status, 400);
                assert_eq!(message, "Constraint violation");
                assert_eq!(code.as_deref(), Some("23514"));
            }
            other => panic!("Expected PostgRest error, got {:?}", other),
        }
    }

    #[tokio::test]
    async fn test_execute_upsert_no_returning() {
        let mock_server =
            mock_users_endpoint(ResponseTemplate::new(201).set_body_string("")).await;

        let (mut parts, params) = sample_row();
        parts.conflict_columns = vec!["id".to_string()];

        let resp = rest_builder(&mock_server.uri(), parts, params)
            .execute()
            .await;
        assert!(resp.is_ok());
        assert!(resp.data.is_empty());
    }
}
303
304// REST-only mode: only DeserializeOwned + Send needed
305#[cfg(not(feature = "direct-sql"))]
306impl<T> UpsertBuilder<T>
307where
308    T: DeserializeOwned + Send,
309{
310    /// Execute the UPSERT query.
311    pub async fn execute(self) -> SupabaseResponse<T> {
312        let QueryBackend::Rest { ref http, ref base_url, ref api_key, ref schema } = self.backend;
313        let (url, headers, body) = match crate::postgrest::build_postgrest_upsert(
314            base_url, &self.parts, &self.params,
315        ) {
316            Ok(r) => r,
317            Err(e) => return SupabaseResponse::error(
318                supabase_client_core::SupabaseError::QueryBuilder(e),
319            ),
320        };
321        crate::postgrest_execute::execute_rest(
322            http, reqwest::Method::POST, &url, headers, Some(body), api_key, schema, &self.parts,
323        ).await
324    }
325}
326
// Direct-SQL mode: additional FromRow + Unpin bounds
#[cfg(feature = "direct-sql")]
impl<T> UpsertBuilder<T>
where
    T: DeserializeOwned + Send + Unpin + for<'r> sqlx::FromRow<'r, sqlx::postgres::PgRow>,
{
    /// Execute the UPSERT query.
    ///
    /// With a `DirectSql` backend the query is run against the pool through
    /// `execute_typed`. With a `Rest` backend a PostgREST request is built and
    /// sent as a `POST`; if building it fails, the failure is returned as a
    /// `SupabaseError::QueryBuilder` error response and nothing is sent.
    pub async fn execute(self) -> SupabaseResponse<T> {
        // Handle the direct-SQL backend up front so the REST path below reads
        // straight through.
        let (http, base_url, api_key, schema) = match &self.backend {
            QueryBackend::DirectSql { pool } => {
                return crate::execute::execute_typed::<T>(pool, &self.parts, &self.params).await;
            }
            QueryBackend::Rest { http, base_url, api_key, schema } => {
                (http, base_url, api_key, schema)
            }
        };

        let request =
            crate::postgrest::build_postgrest_upsert(base_url, &self.parts, &self.params);
        let (url, headers, body) = match request {
            Ok(request) => request,
            Err(err) => {
                return SupabaseResponse::error(
                    supabase_client_core::SupabaseError::QueryBuilder(err),
                );
            }
        };

        crate::postgrest_execute::execute_rest(
            http,
            reqwest::Method::POST,
            &url,
            headers,
            Some(body),
            api_key,
            schema,
            &self.parts,
        )
        .await
    }
}
354}