1use crate::client::Client;
4use crate::error::{Error, Result};
5use crate::internal::{apply_pagination, push_opt};
6use crate::pagination::{FetchFn, Page, PageStream};
7use crate::resources::agencies::urlencoding;
8use crate::Record;
9use bon::Builder;
10use std::collections::BTreeMap;
11use std::sync::Arc;
12
13#[derive(Debug, Clone, Default, Builder, PartialEq, Eq)]
17#[non_exhaustive]
18pub struct ListEntitiesOptions {
19 #[builder(into)]
22 pub page: Option<u32>,
23 #[builder(into)]
25 pub limit: Option<u32>,
26 #[builder(into)]
28 pub cursor: Option<String>,
29 #[builder(into)]
31 pub shape: Option<String>,
32 #[builder(default)]
34 pub flat: bool,
35 #[builder(default)]
37 pub flat_lists: bool,
38
39 #[builder(into)]
42 pub search: Option<String>,
43 #[builder(into)]
45 pub cage_code: Option<String>,
46 #[builder(into)]
48 pub naics: Option<String>,
49 #[builder(into)]
51 pub name: Option<String>,
52 #[builder(into)]
54 pub psc: Option<String>,
55 #[builder(into)]
57 pub purpose_of_registration_code: Option<String>,
58 #[builder(into)]
60 pub socioeconomic: Option<String>,
61 #[builder(into)]
63 pub state: Option<String>,
64 #[builder(into)]
66 pub total_awards_obligated_gte: Option<String>,
67 #[builder(into)]
69 pub total_awards_obligated_lte: Option<String>,
70 #[builder(into)]
72 pub uei: Option<String>,
73 #[builder(into)]
75 pub zip_code: Option<String>,
76
77 #[builder(default)]
79 pub extra: BTreeMap<String, String>,
80}
81
82impl ListEntitiesOptions {
83 pub(crate) fn to_query(&self) -> Vec<(String, String)> {
84 let mut q = Vec::new();
85 apply_pagination(
86 &mut q,
87 self.page,
88 self.limit,
89 self.cursor.as_deref(),
90 self.shape.as_deref(),
91 self.flat,
92 self.flat_lists,
93 );
94 push_opt(&mut q, "search", self.search.as_deref());
95 push_opt(&mut q, "cage_code", self.cage_code.as_deref());
96 push_opt(&mut q, "naics", self.naics.as_deref());
97 push_opt(&mut q, "name", self.name.as_deref());
98 push_opt(&mut q, "psc", self.psc.as_deref());
99 push_opt(
100 &mut q,
101 "purpose_of_registration_code",
102 self.purpose_of_registration_code.as_deref(),
103 );
104 push_opt(&mut q, "socioeconomic", self.socioeconomic.as_deref());
105 push_opt(&mut q, "state", self.state.as_deref());
106 push_opt(
107 &mut q,
108 "total_awards_obligated_gte",
109 self.total_awards_obligated_gte.as_deref(),
110 );
111 push_opt(
112 &mut q,
113 "total_awards_obligated_lte",
114 self.total_awards_obligated_lte.as_deref(),
115 );
116 push_opt(&mut q, "uei", self.uei.as_deref());
117 push_opt(&mut q, "zip_code", self.zip_code.as_deref());
118
119 for (k, v) in &self.extra {
120 if !v.is_empty() {
121 q.push((k.clone(), v.clone()));
122 }
123 }
124 q
125 }
126}
127
128#[derive(Debug, Clone, Default, Builder, PartialEq, Eq)]
130#[non_exhaustive]
131pub struct GetEntityOptions {
132 #[builder(into)]
134 pub shape: Option<String>,
135 #[builder(default)]
137 pub flat: bool,
138 #[builder(default)]
140 pub flat_lists: bool,
141}
142
143impl GetEntityOptions {
144 pub(crate) fn to_query(&self) -> Vec<(String, String)> {
145 let mut q = Vec::new();
146 push_opt(&mut q, "shape", self.shape.as_deref());
147 if self.flat {
148 q.push(("flat".into(), "true".into()));
149 }
150 if self.flat_lists {
151 q.push(("flat_lists".into(), "true".into()));
152 }
153 q
154 }
155}
156
157impl Client {
158 pub async fn list_entities(&self, opts: ListEntitiesOptions) -> Result<Page<Record>> {
160 let q = opts.to_query();
161 let bytes = self.get_bytes("/api/entities/", &q).await?;
162 Page::decode(&bytes)
163 }
164
165 pub async fn get_entity(&self, uei: &str, opts: Option<GetEntityOptions>) -> Result<Record> {
167 if uei.is_empty() {
168 return Err(Error::Validation {
169 message: "get_entity: uei is required".into(),
170 response: None,
171 });
172 }
173 let q = opts.unwrap_or_default().to_query();
174 let path = format!("/api/entities/{}/", urlencoding(uei));
175 self.get_json::<Record>(&path, &q).await
176 }
177
178 pub fn iterate_entities(&self, opts: ListEntitiesOptions) -> PageStream<Record> {
180 let opts = Arc::new(opts);
181 let fetch: FetchFn<Record> = Box::new(move |client, page, cursor| {
182 let mut next = (*opts).clone();
183 next.page = page;
184 next.cursor = cursor;
185 Box::pin(async move { client.list_entities(next).await })
186 });
187 PageStream::new(self.clone(), fetch)
188 }
189}
190
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns the value of the first query pair whose key equals `key`.
    fn param<'a>(query: &'a [(String, String)], key: &str) -> Option<&'a str> {
        query
            .iter()
            .find(|(name, _)| name == key)
            .map(|(_, value)| value.as_str())
    }

    /// Every typed filter set on the builder must appear under its own key.
    #[test]
    fn list_entities_all_filters_emit() {
        let q = ListEntitiesOptions::builder()
            .search("Acme")
            .cage_code("1ABC5")
            .naics("541512")
            .name("Acme Corp")
            .psc("D302")
            .purpose_of_registration_code("Z1")
            .socioeconomic("A5")
            .state("VA")
            .total_awards_obligated_gte("100000")
            .total_awards_obligated_lte("999999")
            .uei("UEI123456789")
            .zip_code("22201")
            .build()
            .to_query();

        let expected = [
            ("search", "Acme"),
            ("cage_code", "1ABC5"),
            ("naics", "541512"),
            ("name", "Acme Corp"),
            ("psc", "D302"),
            ("purpose_of_registration_code", "Z1"),
            ("socioeconomic", "A5"),
            ("state", "VA"),
            ("total_awards_obligated_gte", "100000"),
            ("total_awards_obligated_lte", "999999"),
            ("uei", "UEI123456789"),
            ("zip_code", "22201"),
        ];
        for (key, want) in expected {
            assert_eq!(param(&q, key), Some(want), "query parameter `{key}`");
        }
    }

    /// Default options must not emit any query pair at all.
    #[test]
    fn list_entities_zero_value_omitted() {
        let q = ListEntitiesOptions::builder().build().to_query();
        assert!(q.is_empty(), "expected empty query, got {q:?}");
    }

    /// `page` and `limit` must be rendered as decimal strings.
    #[test]
    fn list_entities_pagination_applied() {
        let q = ListEntitiesOptions::builder()
            .page(2u32)
            .limit(25u32)
            .build()
            .to_query();

        for (key, want) in [("page", "2"), ("limit", "25")] {
            assert_eq!(param(&q, key), Some(want), "query parameter `{key}`");
        }
    }

    /// Shape and both flags must be present when set on `GetEntityOptions`.
    #[test]
    fn get_entity_opts_emits_shape_and_flat() {
        let q = GetEntityOptions::builder()
            .shape(crate::SHAPE_ENTITIES_MINIMAL)
            .flat(true)
            .flat_lists(true)
            .build()
            .to_query();

        let shape_pair: (String, String) =
            ("shape".into(), crate::SHAPE_ENTITIES_MINIMAL.into());
        assert!(q.contains(&shape_pair));

        for flag in ["flat", "flat_lists"] {
            let pair = (flag.to_owned(), "true".to_owned());
            assert!(q.contains(&pair), "missing `{flag}=true` in {q:?}");
        }
    }

    /// An empty UEI must be rejected with a validation error.
    #[tokio::test]
    async fn get_entity_empty_uei_returns_validation() {
        let client = Client::builder().api_key("x").build().expect("build");
        let outcome = client.get_entity("", None).await;
        match outcome.expect_err("must error") {
            Error::Validation { message, .. } => assert!(message.contains("uei")),
            other => panic!("expected Validation, got {other:?}"),
        }
    }
}