1use std::marker::PhantomData;
2
3use serde::de::DeserializeOwned;
4
5use supabase_client_core::SupabaseResponse;
6
7use crate::backend::QueryBackend;
8use crate::modifier::Modifiable;
9use crate::sql::{ParamStore, SqlParts};
10
/// Builder for an upsert statement whose result rows deserialize into `T`.
///
/// Configuration methods consume and return the builder, so calls can be
/// chained before the final `execute`.
pub struct UpsertBuilder<T> {
    /// Backend the finished statement is dispatched to.
    pub(crate) backend: QueryBackend,
    /// Accumulated statement description (conflict target, returning list, ...).
    pub(crate) parts: SqlParts,
    /// Parameter values referenced by index from `parts`.
    pub(crate) params: ParamStore,
    /// Ties the builder to its row type `T` without storing a value of it.
    pub(crate) _marker: PhantomData<T>,
}
19
20impl<T> Modifiable for UpsertBuilder<T> {
21 fn parts_mut(&mut self) -> &mut SqlParts {
22 &mut self.parts
23 }
24}
25
26impl<T> UpsertBuilder<T> {
27 pub fn on_conflict(mut self, columns: &[&str]) -> Self {
29 self.parts.conflict_columns = columns.iter().map(|c| c.to_string()).collect();
30 self
31 }
32
33 pub fn on_conflict_constraint(mut self, constraint: &str) -> Self {
35 self.parts.conflict_constraint = Some(constraint.to_string());
36 self
37 }
38
39 pub fn ignore_duplicates(mut self) -> Self {
43 self.parts.ignore_duplicates = true;
44 self
45 }
46
47 pub fn schema(mut self, schema: &str) -> Self {
51 self.parts.schema_override = Some(schema.to_string());
52 self
53 }
54
55 pub fn select(mut self) -> Self {
57 self.parts.returning = Some("*".to_string());
58 self
59 }
60
61 pub fn select_columns(mut self, columns: &str) -> Self {
63 if columns == "*" || columns.is_empty() {
64 self.parts.returning = Some("*".to_string());
65 } else {
66 let quoted = columns
67 .split(',')
68 .map(|c| {
69 let c = c.trim();
70 if c.contains('(') || c.contains('*') || c.contains('"') {
71 c.to_string()
72 } else {
73 format!("\"{}\"", c)
74 }
75 })
76 .collect::<Vec<_>>()
77 .join(", ");
78 self.parts.returning = Some(quoted);
79 }
80 self
81 }
82}
83
#[cfg(test)]
mod tests {
    use super::*;
    use crate::backend::QueryBackend;
    use crate::sql::{ParamStore, SqlOperation, SqlParam, SqlParts};
    use serde_json::Value as JsonValue;
    use std::marker::PhantomData;
    use std::sync::Arc;
    use wiremock::matchers::{method, path};
    use wiremock::{Mock, MockServer, ResponseTemplate};

    /// Builds an upsert on `public.users` carrying `id = 1` and
    /// `name = "Alice"`, wired to a REST backend rooted at `base_url`.
    fn upsert_builder_at(base_url: &str) -> UpsertBuilder<JsonValue> {
        let mut parts = SqlParts::new(SqlOperation::Upsert, "public", "users");
        let mut params = ParamStore::new();

        let id_idx = params.push(SqlParam::I32(1));
        parts.set_clauses.push(("id".to_string(), id_idx));
        let name_idx = params.push(SqlParam::Text("Alice".to_string()));
        parts.set_clauses.push(("name".to_string(), name_idx));

        UpsertBuilder {
            backend: QueryBackend::Rest {
                http: reqwest::Client::new(),
                base_url: Arc::from(base_url),
                api_key: Arc::from("test-key"),
                schema: "public".to_string(),
            },
            parts,
            params,
            _marker: PhantomData,
        }
    }

    /// Builder used by the pure (non-network) tests.
    fn make_upsert_builder() -> UpsertBuilder<JsonValue> {
        upsert_builder_at("http://localhost")
    }

    /// Registers `response` as the reply to `POST /rest/v1/users` on `server`.
    async fn mount_users_post(server: &MockServer, response: ResponseTemplate) {
        Mock::given(method("POST"))
            .and(path("/rest/v1/users"))
            .respond_with(response)
            .mount(server)
            .await;
    }

    #[test]
    fn test_on_conflict_sets_columns() {
        let builder = make_upsert_builder().on_conflict(&["id"]);
        assert_eq!(builder.parts.conflict_columns, vec!["id".to_string()]);
    }

    #[test]
    fn test_on_conflict_multiple_columns() {
        let builder = make_upsert_builder().on_conflict(&["id", "email"]);
        let expected = vec!["id".to_string(), "email".to_string()];
        assert_eq!(builder.parts.conflict_columns, expected);
    }

    #[test]
    fn test_on_conflict_constraint_sets_name() {
        let builder = make_upsert_builder().on_conflict_constraint("users_pkey");
        assert_eq!(builder.parts.conflict_constraint.as_deref(), Some("users_pkey"));
    }

    #[test]
    fn test_ignore_duplicates_sets_flag() {
        let builder = make_upsert_builder().ignore_duplicates();
        assert!(builder.parts.ignore_duplicates);
    }

    #[test]
    fn test_schema_sets_override() {
        let builder = make_upsert_builder().schema("custom");
        assert_eq!(builder.parts.schema_override.as_deref(), Some("custom"));
    }

    #[test]
    fn test_select_sets_returning_star() {
        let builder = make_upsert_builder().select();
        assert_eq!(builder.parts.returning.as_deref(), Some("*"));
    }

    #[test]
    fn test_select_columns_star() {
        let builder = make_upsert_builder().select_columns("*");
        assert_eq!(builder.parts.returning.as_deref(), Some("*"));
    }

    #[test]
    fn test_select_columns_empty() {
        let builder = make_upsert_builder().select_columns("");
        assert_eq!(builder.parts.returning.as_deref(), Some("*"));
    }

    #[test]
    fn test_select_columns_specific() {
        let builder = make_upsert_builder().select_columns("id, name");
        assert_eq!(builder.parts.returning.as_deref(), Some("\"id\", \"name\""));
    }

    #[test]
    fn test_select_columns_complex() {
        let builder = make_upsert_builder().select_columns("count(*)");
        assert_eq!(builder.parts.returning.as_deref(), Some("count(*)"));
    }

    #[tokio::test]
    async fn test_execute_upsert_success() {
        let mock_server = MockServer::start().await;
        let created = ResponseTemplate::new(201)
            .set_body_json(serde_json::json!([{"id": 1, "name": "Alice"}]));
        mount_users_post(&mock_server, created).await;

        let mut builder = upsert_builder_at(&mock_server.uri());
        builder.parts.conflict_columns = vec!["id".to_string()];
        builder.parts.returning = Some("*".to_string());

        let resp = builder.execute().await;
        assert!(resp.is_ok());
        assert_eq!(resp.data.len(), 1);
        assert_eq!(resp.data[0]["name"], "Alice");
        assert_eq!(resp.status, supabase_client_core::StatusCode::Created);
    }

    #[tokio::test]
    async fn test_execute_upsert_error() {
        let mock_server = MockServer::start().await;
        let rejected = ResponseTemplate::new(400).set_body_json(serde_json::json!({
            "message": "Constraint violation",
            "code": "23514"
        }));
        mount_users_post(&mock_server, rejected).await;

        let mut builder = upsert_builder_at(&mock_server.uri());
        builder.parts.conflict_columns = vec!["id".to_string()];

        let resp = builder.execute().await;
        assert!(resp.is_err());
        match resp.error.as_ref().unwrap() {
            supabase_client_core::SupabaseError::PostgRest { status, message, code } => {
                assert_eq!(*status, 400);
                assert_eq!(message, "Constraint violation");
                assert_eq!(code.as_deref(), Some("23514"));
            }
            other => panic!("Expected PostgRest error, got {:?}", other),
        }
    }

    #[tokio::test]
    async fn test_execute_upsert_no_returning() {
        let mock_server = MockServer::start().await;
        let empty_created = ResponseTemplate::new(201).set_body_string("");
        mount_users_post(&mock_server, empty_created).await;

        let mut builder = upsert_builder_at(&mock_server.uri());
        builder.parts.conflict_columns = vec!["id".to_string()];

        let resp = builder.execute().await;
        assert!(resp.is_ok());
        assert!(resp.data.is_empty());
    }
}
303
304#[cfg(not(feature = "direct-sql"))]
306impl<T> UpsertBuilder<T>
307where
308 T: DeserializeOwned + Send,
309{
310 pub async fn execute(self) -> SupabaseResponse<T> {
312 let QueryBackend::Rest { ref http, ref base_url, ref api_key, ref schema } = self.backend;
313 let (url, headers, body) = match crate::postgrest::build_postgrest_upsert(
314 base_url, &self.parts, &self.params,
315 ) {
316 Ok(r) => r,
317 Err(e) => return SupabaseResponse::error(
318 supabase_client_core::SupabaseError::QueryBuilder(e),
319 ),
320 };
321 crate::postgrest_execute::execute_rest(
322 http, reqwest::Method::POST, &url, headers, Some(body), api_key, schema, &self.parts,
323 ).await
324 }
325}
326
#[cfg(feature = "direct-sql")]
impl<T> UpsertBuilder<T>
where
    T: DeserializeOwned + Send + Unpin + for<'r> sqlx::FromRow<'r, sqlx::postgres::PgRow>,
{
    /// Runs the upsert on whichever backend this builder was created with.
    ///
    /// * `DirectSql`: delegates to `execute_typed` with the connection pool.
    /// * `Rest`: assembles the request with `build_postgrest_upsert` and
    ///   sends it as a `POST` via `execute_rest`. If the request cannot be
    ///   built, the builder error is returned wrapped in
    ///   `SupabaseError::QueryBuilder` and nothing is sent.
    pub async fn execute(self) -> SupabaseResponse<T> {
        // Handle the direct-SQL backend first; everything below is REST-only.
        let (http, base_url, api_key, schema) = match &self.backend {
            QueryBackend::DirectSql { pool } => {
                return crate::execute::execute_typed::<T>(pool, &self.parts, &self.params).await;
            }
            QueryBackend::Rest { http, base_url, api_key, schema } => {
                (http, base_url, api_key, schema)
            }
        };

        let request = crate::postgrest::build_postgrest_upsert(base_url, &self.parts, &self.params);
        let (url, headers, body) = match request {
            Ok(built) => built,
            Err(build_err) => {
                let wrapped = supabase_client_core::SupabaseError::QueryBuilder(build_err);
                return SupabaseResponse::error(wrapped);
            }
        };

        crate::postgrest_execute::execute_rest(
            http,
            reqwest::Method::POST,
            &url,
            headers,
            Some(body),
            api_key,
            schema,
            &self.parts,
        )
        .await
    }
}