1use crate::server_fn::ServerFnError;
7use serde::{Deserialize, Serialize, de::DeserializeOwned};
8use std::marker::PhantomData;
9
10#[derive(Debug, Clone, Serialize, Deserialize, Default)]
12pub enum FilterOp {
13 #[default]
15 Exact,
16 IExact,
18 Contains,
20 IContains,
22 Gt,
24 Gte,
26 Lt,
28 Lte,
30 StartsWith,
32 IStartsWith,
34 EndsWith,
36 IEndsWith,
38 In,
40 IsNull,
42 Range,
44}
45
46#[derive(Debug, Clone, Serialize, Deserialize)]
48pub struct Filter {
49 pub field: String,
51 pub op: FilterOp,
53 pub value: serde_json::Value,
55 pub exclude: bool,
57}
58
59impl Filter {
60 pub fn exact(field: impl Into<String>, value: impl Serialize) -> Self {
62 Self {
63 field: field.into(),
64 op: FilterOp::Exact,
65 value: serde_json::to_value(value).unwrap_or(serde_json::Value::Null),
66 exclude: false,
67 }
68 }
69
70 pub fn with_op(field: impl Into<String>, op: FilterOp, value: impl Serialize) -> Self {
72 Self {
73 field: field.into(),
74 op,
75 value: serde_json::to_value(value).unwrap_or(serde_json::Value::Null),
76 exclude: false,
77 }
78 }
79
80 pub fn negate(mut self) -> Self {
82 self.exclude = !self.exclude;
83 self
84 }
85
86 pub fn to_query_param(&self) -> (String, String) {
88 let key = match self.op {
89 FilterOp::Exact => self.field.clone(),
90 FilterOp::IExact => format!("{}__iexact", self.field),
91 FilterOp::Contains => format!("{}__contains", self.field),
92 FilterOp::IContains => format!("{}__icontains", self.field),
93 FilterOp::Gt => format!("{}__gt", self.field),
94 FilterOp::Gte => format!("{}__gte", self.field),
95 FilterOp::Lt => format!("{}__lt", self.field),
96 FilterOp::Lte => format!("{}__lte", self.field),
97 FilterOp::StartsWith => format!("{}__startswith", self.field),
98 FilterOp::IStartsWith => format!("{}__istartswith", self.field),
99 FilterOp::EndsWith => format!("{}__endswith", self.field),
100 FilterOp::IEndsWith => format!("{}__iendswith", self.field),
101 FilterOp::In => format!("{}__in", self.field),
102 FilterOp::IsNull => format!("{}__isnull", self.field),
103 FilterOp::Range => format!("{}__range", self.field),
104 };
105
106 let value = match &self.value {
107 serde_json::Value::String(s) => s.clone(),
108 serde_json::Value::Number(n) => n.to_string(),
109 serde_json::Value::Bool(b) => b.to_string(),
110 serde_json::Value::Array(arr) => arr
111 .iter()
112 .map(|v| match v {
113 serde_json::Value::String(s) => s.clone(),
114 other => other.to_string(),
115 })
116 .collect::<Vec<_>>()
117 .join(","),
118 serde_json::Value::Null => "null".to_string(),
119 other => other.to_string(),
120 };
121
122 (key, value)
123 }
124}
125
126#[derive(Debug, Clone)]
131pub struct ApiQuerySet<T> {
132 endpoint: String,
134 filters: Vec<Filter>,
136 ordering: Vec<String>,
138 limit: Option<usize>,
140 offset: Option<usize>,
142 fields: Vec<String>,
144 _marker: PhantomData<T>,
146}
147
148impl<T> ApiQuerySet<T>
149where
150 T: Serialize + DeserializeOwned,
151{
152 pub fn new(endpoint: impl Into<String>) -> Self {
154 Self {
155 endpoint: endpoint.into(),
156 filters: Vec::new(),
157 ordering: Vec::new(),
158 limit: None,
159 offset: None,
160 fields: Vec::new(),
161 _marker: PhantomData,
162 }
163 }
164
165 pub fn filter(mut self, field: impl Into<String>, value: impl Serialize) -> Self {
172 self.filters.push(Filter::exact(field, value));
173 self
174 }
175
176 pub fn filter_op(
183 mut self,
184 field: impl Into<String>,
185 op: FilterOp,
186 value: impl Serialize,
187 ) -> Self {
188 self.filters.push(Filter::with_op(field, op, value));
189 self
190 }
191
192 pub fn exclude(mut self, field: impl Into<String>, value: impl Serialize) -> Self {
199 self.filters.push(Filter::exact(field, value).negate());
200 self
201 }
202
203 pub fn order_by(mut self, fields: &[&str]) -> Self {
212 self.ordering = fields.iter().map(|s| (*s).to_string()).collect();
213 self
214 }
215
216 pub fn limit(mut self, n: usize) -> Self {
223 self.limit = Some(n);
224 self
225 }
226
227 pub fn offset(mut self, n: usize) -> Self {
234 self.offset = Some(n);
235 self
236 }
237
238 pub fn only(mut self, fields: &[&str]) -> Self {
245 self.fields = fields.iter().map(|s| (*s).to_string()).collect();
246 self
247 }
248
249 pub fn all_clone(&self) -> Self {
251 Self::new(&self.endpoint)
252 }
253
254 pub fn build_url(&self) -> String {
256 let mut params: Vec<(String, String)> = Vec::new();
257
258 for filter in &self.filters {
260 let (key, value) = filter.to_query_param();
261 if filter.exclude {
262 params.push((format!("exclude__{}", key), value));
263 } else {
264 params.push((key, value));
265 }
266 }
267
268 if !self.ordering.is_empty() {
270 params.push(("ordering".to_string(), self.ordering.join(",")));
271 }
272
273 if let Some(limit) = self.limit {
275 params.push(("limit".to_string(), limit.to_string()));
276 }
277 if let Some(offset) = self.offset {
278 params.push(("offset".to_string(), offset.to_string()));
279 }
280
281 if !self.fields.is_empty() {
283 params.push(("fields".to_string(), self.fields.join(",")));
284 }
285
286 if params.is_empty() {
288 self.endpoint.clone()
289 } else {
290 let query_string = params
291 .iter()
292 .map(|(k, v)| format!("{}={}", urlencoding::encode(k), urlencoding::encode(v)))
293 .collect::<Vec<_>>()
294 .join("&");
295 format!("{}?{}", self.endpoint, query_string)
296 }
297 }
298
299 #[cfg(target_arch = "wasm32")]
301 pub async fn all(&self) -> Result<Vec<T>, ServerFnError> {
302 use crate::csrf::csrf_headers;
303 use gloo_net::http::Request;
304
305 let url = self.build_url();
306 let mut request = Request::get(&url);
307
308 if let Some((header_name, header_value)) = csrf_headers() {
310 request = request.header(header_name, &header_value);
311 }
312
313 let response = request
314 .send()
315 .await
316 .map_err(|e| ServerFnError::Network(e.to_string()))?;
317
318 if !response.ok() {
319 return Err(ServerFnError::Server {
320 status: response.status(),
321 message: response.status_text(),
322 });
323 }
324
325 response
326 .json()
327 .await
328 .map_err(|e| ServerFnError::Deserialization(e.to_string()))
329 }
330
331 #[cfg(not(target_arch = "wasm32"))]
333 pub async fn all(&self) -> Result<Vec<T>, ServerFnError> {
334 Err(ServerFnError::Network(
335 "API calls not supported outside WASM".to_string(),
336 ))
337 }
338
339 #[cfg(target_arch = "wasm32")]
341 pub async fn first(&self) -> Result<Option<T>, ServerFnError>
342 where
343 T: Clone,
344 {
345 let mut queryset = self.clone();
346 queryset.limit = Some(1);
347 let results = queryset.all().await?;
348 Ok(results.into_iter().next())
349 }
350
351 #[cfg(not(target_arch = "wasm32"))]
353 pub async fn first(&self) -> Result<Option<T>, ServerFnError> {
354 Err(ServerFnError::Network(
355 "API calls not supported outside WASM".to_string(),
356 ))
357 }
358
359 #[cfg(target_arch = "wasm32")]
361 pub async fn get(&self, pk: impl std::fmt::Display) -> Result<T, ServerFnError> {
362 use crate::csrf::csrf_headers;
363 use gloo_net::http::Request;
364
365 let url = format!("{}{}/", self.endpoint.trim_end_matches('/'), pk);
366 let mut builder = Request::get(&url);
367
368 if let Some((header_name, header_value)) = csrf_headers() {
369 builder = builder.header(header_name, &header_value);
370 }
371
372 let response = builder
373 .send()
374 .await
375 .map_err(|e| ServerFnError::Network(e.to_string()))?;
376
377 if !response.ok() {
378 return Err(ServerFnError::Server {
379 status: response.status(),
380 message: response.status_text(),
381 });
382 }
383
384 response
385 .json()
386 .await
387 .map_err(|e| ServerFnError::Deserialization(e.to_string()))
388 }
389
390 #[cfg(not(target_arch = "wasm32"))]
392 pub async fn get(&self, _pk: impl std::fmt::Display) -> Result<T, ServerFnError> {
393 Err(ServerFnError::Network(
394 "API calls not supported outside WASM".to_string(),
395 ))
396 }
397
398 #[cfg(target_arch = "wasm32")]
400 pub async fn count(&self) -> Result<usize, ServerFnError> {
401 use crate::csrf::csrf_headers;
402 use gloo_net::http::Request;
403
404 let url = format!("{}?count=true", self.build_url());
406 let mut builder = Request::get(&url);
407
408 if let Some((header_name, header_value)) = csrf_headers() {
409 builder = builder.header(header_name, &header_value);
410 }
411
412 let response = builder
413 .send()
414 .await
415 .map_err(|e| ServerFnError::Network(e.to_string()))?;
416
417 if !response.ok() {
418 return Err(ServerFnError::Server {
419 status: response.status(),
420 message: response.status_text(),
421 });
422 }
423
424 #[derive(Deserialize)]
425 struct CountResponse {
426 count: usize,
427 }
428
429 let result: CountResponse = response
430 .json()
431 .await
432 .map_err(|e| ServerFnError::Deserialization(e.to_string()))?;
433
434 Ok(result.count)
435 }
436
437 #[cfg(not(target_arch = "wasm32"))]
439 pub async fn count(&self) -> Result<usize, ServerFnError> {
440 Err(ServerFnError::Network(
441 "API calls not supported outside WASM".to_string(),
442 ))
443 }
444
445 pub async fn exists(&self) -> Result<bool, ServerFnError>
447 where
448 Self: Clone,
449 {
450 let count = self.clone().limit(1).count().await?;
451 Ok(count > 0)
452 }
453
454 #[cfg(target_arch = "wasm32")]
456 pub async fn create(&self, data: &T) -> Result<T, ServerFnError> {
457 use crate::csrf::csrf_headers;
458 use gloo_net::http::Request;
459
460 let mut builder = Request::post(&self.endpoint);
461
462 if let Some((header_name, header_value)) = csrf_headers() {
463 builder = builder.header(header_name, &header_value);
464 }
465
466 let request = builder
467 .json(data)
468 .map_err(|e| ServerFnError::Serialization(e.to_string()))?;
469
470 let response = request
471 .send()
472 .await
473 .map_err(|e| ServerFnError::Network(e.to_string()))?;
474
475 if !response.ok() {
476 return Err(ServerFnError::Server {
477 status: response.status(),
478 message: response.status_text(),
479 });
480 }
481
482 response
483 .json()
484 .await
485 .map_err(|e| ServerFnError::Deserialization(e.to_string()))
486 }
487
488 #[cfg(not(target_arch = "wasm32"))]
490 pub async fn create(&self, _data: &T) -> Result<T, ServerFnError> {
491 Err(ServerFnError::Network(
492 "API calls not supported outside WASM".to_string(),
493 ))
494 }
495
496 #[cfg(target_arch = "wasm32")]
498 pub async fn update(&self, pk: impl std::fmt::Display, data: &T) -> Result<T, ServerFnError> {
499 use crate::csrf::csrf_headers;
500 use gloo_net::http::Request;
501
502 let url = format!("{}{}/", self.endpoint.trim_end_matches('/'), pk);
503 let mut builder = Request::put(&url);
504
505 if let Some((header_name, header_value)) = csrf_headers() {
506 builder = builder.header(header_name, &header_value);
507 }
508
509 let request = builder
510 .json(data)
511 .map_err(|e| ServerFnError::Serialization(e.to_string()))?;
512
513 let response = request
514 .send()
515 .await
516 .map_err(|e| ServerFnError::Network(e.to_string()))?;
517
518 if !response.ok() {
519 return Err(ServerFnError::Server {
520 status: response.status(),
521 message: response.status_text(),
522 });
523 }
524
525 response
526 .json()
527 .await
528 .map_err(|e| ServerFnError::Deserialization(e.to_string()))
529 }
530
531 #[cfg(not(target_arch = "wasm32"))]
533 pub async fn update(&self, _pk: impl std::fmt::Display, _data: &T) -> Result<T, ServerFnError> {
534 Err(ServerFnError::Network(
535 "API calls not supported outside WASM".to_string(),
536 ))
537 }
538
539 #[cfg(target_arch = "wasm32")]
541 pub async fn partial_update(
542 &self,
543 pk: impl std::fmt::Display,
544 data: &serde_json::Value,
545 ) -> Result<T, ServerFnError> {
546 use crate::csrf::csrf_headers;
547 use gloo_net::http::Request;
548
549 let url = format!("{}{}/", self.endpoint.trim_end_matches('/'), pk);
550 let mut builder = Request::patch(&url);
551
552 if let Some((header_name, header_value)) = csrf_headers() {
553 builder = builder.header(header_name, &header_value);
554 }
555
556 let request = builder
557 .json(data)
558 .map_err(|e| ServerFnError::Serialization(e.to_string()))?;
559
560 let response = request
561 .send()
562 .await
563 .map_err(|e| ServerFnError::Network(e.to_string()))?;
564
565 if !response.ok() {
566 return Err(ServerFnError::Server {
567 status: response.status(),
568 message: response.status_text(),
569 });
570 }
571
572 response
573 .json()
574 .await
575 .map_err(|e| ServerFnError::Deserialization(e.to_string()))
576 }
577
578 #[cfg(not(target_arch = "wasm32"))]
580 pub async fn partial_update(
581 &self,
582 _pk: impl std::fmt::Display,
583 _data: &serde_json::Value,
584 ) -> Result<T, ServerFnError> {
585 Err(ServerFnError::Network(
586 "API calls not supported outside WASM".to_string(),
587 ))
588 }
589
590 #[cfg(target_arch = "wasm32")]
592 pub async fn delete(&self, pk: impl std::fmt::Display) -> Result<(), ServerFnError> {
593 use crate::csrf::csrf_headers;
594 use gloo_net::http::Request;
595
596 let url = format!("{}{}/", self.endpoint.trim_end_matches('/'), pk);
597 let mut builder = Request::delete(&url);
598
599 if let Some((header_name, header_value)) = csrf_headers() {
600 builder = builder.header(header_name, &header_value);
601 }
602
603 let response = builder
604 .send()
605 .await
606 .map_err(|e| ServerFnError::Network(e.to_string()))?;
607
608 if !response.ok() {
609 return Err(ServerFnError::Server {
610 status: response.status(),
611 message: response.status_text(),
612 });
613 }
614
615 Ok(())
616 }
617
618 #[cfg(not(target_arch = "wasm32"))]
620 pub async fn delete(&self, _pk: impl std::fmt::Display) -> Result<(), ServerFnError> {
621 Err(ServerFnError::Network(
622 "API calls not supported outside WASM".to_string(),
623 ))
624 }
625}
626
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_filter_exact() {
        let filter = Filter::exact("name", "test");
        assert_eq!(filter.field, "name");
        assert!(!filter.exclude);
        let (key, value) = filter.to_query_param();
        assert_eq!(key, "name");
        assert_eq!(value, "test");
    }

    #[test]
    fn test_filter_with_op() {
        let filter = Filter::with_op("age", FilterOp::Gte, 18);
        let (key, value) = filter.to_query_param();
        assert_eq!(key, "age__gte");
        assert_eq!(value, "18");
    }

    #[test]
    fn test_filter_negate() {
        let filter = Filter::exact("status", "banned").negate();
        assert!(filter.exclude);
    }

    #[test]
    fn test_filter_negate_twice_restores() {
        let filter = Filter::exact("status", "banned").negate().negate();
        assert!(!filter.exclude);
    }

    #[test]
    fn test_filter_op_default_is_exact() {
        assert!(matches!(FilterOp::default(), FilterOp::Exact));
    }

    #[test]
    fn test_filter_null_value() {
        let filter = Filter::exact("deleted_at", Option::<i32>::None);
        let (key, value) = filter.to_query_param();
        assert_eq!(key, "deleted_at");
        assert_eq!(value, "null");
    }

    #[test]
    fn test_filter_isnull_and_range() {
        let (key, value) = Filter::with_op("deleted_at", FilterOp::IsNull, true).to_query_param();
        assert_eq!(key, "deleted_at__isnull");
        assert_eq!(value, "true");

        let (key, value) =
            Filter::with_op("created", FilterOp::Range, vec!["2020-01-01", "2020-12-31"])
                .to_query_param();
        assert_eq!(key, "created__range");
        assert_eq!(value, "2020-01-01,2020-12-31");
    }

    #[test]
    fn test_queryset_build_url_simple() {
        let qs: ApiQuerySet<serde_json::Value> = ApiQuerySet::new("/api/users/");
        assert_eq!(qs.build_url(), "/api/users/");
    }

    #[test]
    fn test_queryset_build_url_with_filters() {
        let qs: ApiQuerySet<serde_json::Value> = ApiQuerySet::new("/api/users/")
            .filter("is_active", true)
            .filter_op("age", FilterOp::Gte, 18);

        let url = qs.build_url();
        assert!(url.contains("is_active=true"));
        assert!(url.contains("age__gte=18"));
    }

    #[test]
    fn test_queryset_build_url_encodes_values() {
        let qs: ApiQuerySet<serde_json::Value> =
            ApiQuerySet::new("/api/users/").filter("name", "a b&c");
        assert_eq!(qs.build_url(), "/api/users/?name=a%20b%26c");
    }

    #[test]
    fn test_queryset_build_url_with_ordering() {
        let qs: ApiQuerySet<serde_json::Value> =
            ApiQuerySet::new("/api/users/").order_by(&["-created_at", "username"]);

        let url = qs.build_url();
        assert!(url.contains("ordering=-created_at%2Cusername"));
    }

    #[test]
    fn test_queryset_order_by_replaces_previous() {
        let qs: ApiQuerySet<serde_json::Value> = ApiQuerySet::new("/api/users/")
            .order_by(&["a"])
            .order_by(&["b"]);
        assert_eq!(qs.build_url(), "/api/users/?ordering=b");
    }

    #[test]
    fn test_queryset_build_url_with_pagination() {
        let qs: ApiQuerySet<serde_json::Value> =
            ApiQuerySet::new("/api/users/").limit(10).offset(20);

        let url = qs.build_url();
        assert!(url.contains("limit=10"));
        assert!(url.contains("offset=20"));
    }

    #[test]
    fn test_queryset_build_url_with_fields() {
        let qs: ApiQuerySet<serde_json::Value> =
            ApiQuerySet::new("/api/users/").only(&["id", "username"]);

        let url = qs.build_url();
        assert!(url.contains("fields=id%2Cusername"));
    }

    #[test]
    fn test_queryset_param_order() {
        // Filters come first, then ordering, limit, offset and fields,
        // regardless of the order in which the builder methods were called.
        let qs: ApiQuerySet<serde_json::Value> = ApiQuerySet::new("/api/users/")
            .only(&["id"])
            .offset(5)
            .limit(2)
            .order_by(&["id"])
            .filter("a", 1);
        assert_eq!(
            qs.build_url(),
            "/api/users/?a=1&ordering=id&limit=2&offset=5&fields=id"
        );
    }

    #[test]
    fn test_queryset_all_clone_resets_query() {
        let qs: ApiQuerySet<serde_json::Value> = ApiQuerySet::new("/api/users/")
            .filter("is_active", true)
            .limit(3);
        assert_eq!(qs.all_clone().build_url(), "/api/users/");
    }

    #[test]
    fn test_queryset_chain() {
        let qs: ApiQuerySet<serde_json::Value> = ApiQuerySet::new("/api/users/")
            .filter("is_active", true)
            .exclude("role", "admin")
            .order_by(&["-created_at"])
            .limit(10)
            .offset(0);

        let url = qs.build_url();
        assert!(url.starts_with("/api/users/?"));
        assert!(url.contains("is_active=true"));
        assert!(url.contains("exclude__role=admin"));
        assert!(url.contains("ordering=-created_at"));
        assert!(url.contains("limit=10"));
    }

    #[test]
    fn test_filter_in_list() {
        let filter = Filter::with_op("id", FilterOp::In, vec![1, 2, 3]);
        let (key, value) = filter.to_query_param();
        assert_eq!(key, "id__in");
        assert_eq!(value, "1,2,3");
    }
}