1use std::{collections::HashMap, path::Path, str::FromStr, sync::Arc};
2
3use reqwest::{
4 Method, StatusCode,
5 header::{ACCEPT, HeaderMap, HeaderName, InvalidHeaderValue},
6 multipart,
7};
8use serde::Deserialize;
9use tracing::{debug, trace};
10
11use crate::{
12 Error, Result,
13 correspondent::{Correspondent, CorrespondentId},
14 custom_field::{CustomField, CustomFieldId},
15 document::{Document, DocumentId},
16 document_type::{DocumentType, DocumentTypeId},
17 tag::{Tag, TagId},
18 task::{Task, TaskId},
19};
20
/// Selects which cached metadata table [`PaperlessClient::refresh`] should reload.
///
/// Derives `Hash` (in addition to `Eq`) so callers can collect requested refreshes into a
/// `HashSet` or use them as map keys.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
pub enum RefreshData {
    /// The tag cache, loaded from `/api/tags/`.
    Tags,
    /// The custom-field cache, loaded from `/api/custom_fields/`.
    CustomFields,
    /// The correspondent cache, loaded from `/api/correspondents/`.
    Correspondents,
    /// The document-type cache, loaded from `/api/document_types/`.
    DocumentTypes,
}
29
30#[derive(Debug, Clone)]
32pub struct PaperlessClient {
33 client: reqwest::Client,
34 base_url: String,
35
36 correspondents: HashMap<CorrespondentId, Correspondent>,
37 document_types: HashMap<DocumentTypeId, DocumentType>,
38 tags: HashMap<TagId, Tag>,
39 custom_fields: HashMap<CustomFieldId, CustomField>,
40}
41
42#[derive(Debug, Deserialize)]
43struct PaginatedResponse<T> {
44 results: Vec<T>,
45 next: Option<String>,
46}
47
48impl PaperlessClient {
49 pub fn new(
51 base_url: &str,
52 token: &str,
53 headers: Option<&HashMap<String, String>>,
54 ) -> std::result::Result<Self, String> {
55 let mut headers_map = HeaderMap::new();
56
57 if let Some(headers) = headers {
59 for (key, value) in headers {
60 headers_map.insert(
61 HeaderName::from_str(key).map_err(|err| err.to_string())?,
62 value
63 .parse()
64 .map_err(|err: InvalidHeaderValue| err.to_string())?,
65 );
66 }
67 }
68
69 headers_map.insert(
71 HeaderName::from_str("Authorization").map_err(|err| err.to_string())?,
72 format!("Token {token}")
73 .parse()
74 .map_err(|err: InvalidHeaderValue| err.to_string())?,
75 );
76
77 Ok(Self {
78 base_url: base_url.to_string(),
79 client: reqwest::Client::builder()
80 .default_headers(headers_map)
81 .build()
82 .map_err(|err| err.to_string())?,
83 tags: HashMap::new(),
84 custom_fields: HashMap::new(),
85 correspondents: HashMap::new(),
86 document_types: HashMap::new(),
87 })
88 }
89
90 async fn load_tags(&self) -> Result<HashMap<TagId, Tag>> {
91 debug!("loading tags");
92 let tags: Vec<Tag> = self.fetch_all_pages("/api/tags/").await?;
93 Ok(tags.into_iter().map(|tag| (tag.id, tag)).collect())
94 }
95
96 async fn load_custom_fields(&self) -> Result<HashMap<CustomFieldId, CustomField>> {
97 debug!("loading custom fields");
98 let custom_fields: Vec<CustomField> = self.fetch_all_pages("/api/custom_fields/").await?;
99 Ok(custom_fields
100 .into_iter()
101 .map(|custom_field| (custom_field.id, custom_field))
102 .collect())
103 }
104
105 async fn load_correspondents(&self) -> Result<HashMap<CorrespondentId, Correspondent>> {
106 debug!("loading correspondents");
107 let correspondents: Vec<Correspondent> =
108 self.fetch_all_pages("/api/correspondents/").await?;
109 Ok(correspondents
110 .into_iter()
111 .map(|correspondent| (correspondent.id, correspondent))
112 .collect())
113 }
114
115 async fn load_document_types(&self) -> Result<HashMap<DocumentTypeId, DocumentType>> {
116 debug!("loading document types");
117 let document_types: Vec<DocumentType> =
118 self.fetch_all_pages("/api/document_types/").await?;
119 Ok(document_types
120 .into_iter()
121 .map(|document_type| (document_type.id, document_type))
122 .collect())
123 }
124
125 pub async fn refresh(&mut self, data: impl IntoIterator<Item = RefreshData>) -> Result<()> {
127 let mut refresh_tags = false;
128 let mut refresh_custom_fields = false;
129 let mut refresh_correspondents = false;
130 let mut refresh_document_types = false;
131
132 for item in data {
133 match item {
134 RefreshData::Tags => refresh_tags = true,
135 RefreshData::CustomFields => refresh_custom_fields = true,
136 RefreshData::Correspondents => refresh_correspondents = true,
137 RefreshData::DocumentTypes => refresh_document_types = true,
138 }
139 }
140
141 let (tags, custom_fields, correspondents, document_types) = futures_util::try_join!(
142 async {
143 if refresh_tags {
144 Ok::<Option<HashMap<TagId, Tag>>, Error>(Some(self.load_tags().await?))
145 } else {
146 Ok::<Option<HashMap<TagId, Tag>>, Error>(None)
147 }
148 },
149 async {
150 if refresh_custom_fields {
151 Ok::<Option<HashMap<CustomFieldId, CustomField>>, Error>(Some(
152 self.load_custom_fields().await?,
153 ))
154 } else {
155 Ok::<Option<HashMap<CustomFieldId, CustomField>>, Error>(None)
156 }
157 },
158 async {
159 if refresh_correspondents {
160 Ok::<Option<HashMap<CorrespondentId, Correspondent>>, Error>(Some(
161 self.load_correspondents().await?,
162 ))
163 } else {
164 Ok::<Option<HashMap<CorrespondentId, Correspondent>>, Error>(None)
165 }
166 },
167 async {
168 if refresh_document_types {
169 Ok::<Option<HashMap<DocumentTypeId, DocumentType>>, Error>(Some(
170 self.load_document_types().await?,
171 ))
172 } else {
173 Ok::<Option<HashMap<DocumentTypeId, DocumentType>>, Error>(None)
174 }
175 },
176 )?;
177
178 if let Some(tags) = tags {
179 self.tags = tags;
180 }
181
182 if let Some(custom_fields) = custom_fields {
183 self.custom_fields = custom_fields;
184 }
185
186 if let Some(correspondents) = correspondents {
187 self.correspondents = correspondents;
188 }
189
190 if let Some(document_types) = document_types {
191 self.document_types = document_types;
192 }
193
194 Ok(())
195 }
196
197 #[inline]
199 pub async fn refresh_tags(&mut self) -> Result<()> {
200 self.refresh([RefreshData::Tags]).await
201 }
202
203 #[inline]
205 pub async fn refresh_custom_fields(&mut self) -> Result<()> {
206 self.refresh([RefreshData::CustomFields]).await
207 }
208
209 #[inline]
211 pub async fn refresh_correspondents(&mut self) -> Result<()> {
212 self.refresh([RefreshData::Correspondents]).await
213 }
214
215 #[inline]
217 pub async fn refresh_document_types(&mut self) -> Result<()> {
218 self.refresh([RefreshData::DocumentTypes]).await
219 }
220
221 pub async fn get_documents_by_tags(
223 &self,
224 tag_ids: &[TagId],
225 truncate_content: bool,
226 ) -> Result<Vec<Document>> {
227 let tag_id_str = tag_ids
228 .iter()
229 .map(|tag_id| tag_id.0.to_string())
230 .collect::<Vec<_>>()
231 .join(",");
232 let mut documents: Vec<Document> = self
233 .fetch_all_pages(&format!(
234 "/api/documents/?truncate_content={truncate_content}&tags__id__in={tag_id_str}"
235 ))
236 .await?;
237
238 for doc in &mut documents {
239 doc.client = Some(Arc::new(self.clone()));
240 doc.content_is_truncated = truncate_content;
241 }
242
243 Ok(documents)
244 }
245
246 pub async fn get_document_by_id(&self, id: DocumentId) -> Result<Document> {
248 let resp = self
249 .request(Method::GET, &format!("/api/documents/{}/", id.0), None)
250 .await?;
251
252 let mut document: Document = resp
253 .json()
254 .await
255 .map_err(|e| Error::Other(format!("Failed to parse document: {e}")))?;
256
257 document.client = Some(Arc::new(self.clone()));
258 Ok(document)
259 }
260
261 pub(crate) async fn request(
262 &self,
263 method: Method,
264 endpoint: &str,
265 body: Option<&serde_json::Value>,
266 ) -> Result<reqwest::Response> {
267 let mut req = self
268 .client
269 .request(method, format!("{}{endpoint}", self.base_url))
270 .header(ACCEPT, "application/json");
271
272 if let Some(json_body) = body {
273 req = req.json(json_body);
274 }
275
276 let resp = req
277 .send()
278 .await
279 .map_err(|e| Error::Other(format!("Failed to send request: {e}")))?;
280
281 if resp.status() == StatusCode::NOT_FOUND {
282 return Err(Error::NotFound);
283 }
284
285 if !resp.status().is_success() {
286 return Err(Error::Response {
287 status_code: resp.status().as_u16(),
288 body: resp.text().await.unwrap_or_default(),
289 });
290 }
291
292 Ok(resp)
293 }
294
295 pub(crate) async fn fetch_all_pages<T: for<'de> Deserialize<'de>>(
296 &self,
297 endpoint: &str,
298 ) -> Result<Vec<T>> {
299 let mut results = Vec::new();
300 let mut current_url = Some(endpoint.to_string());
301
302 while let Some(url) = current_url {
303 let resp = self.request(Method::GET, &url, None).await?;
304
305 let page: PaginatedResponse<T> = resp.json().await.map_err(|e| {
306 Error::Other(format!(
307 "Failed to parse paginated response for {endpoint}: {e}"
308 ))
309 })?;
310
311 results.extend(page.results);
312
313 current_url = page.next.and_then(|next_url| {
314 next_url
316 .trim_start_matches(&self.base_url)
317 .to_string()
318 .into()
319 });
320 }
321
322 Ok(results)
323 }
324
325 pub async fn get_task_status(
327 &self,
328 task_id: Option<&TaskId>,
329 task_name: Option<&str>,
330 acknowledged: Option<bool>,
331 ) -> Result<Vec<Task>> {
332 let mut query = Vec::new();
333
334 if let Some(id) = task_id {
335 query.push(("task_id", id.to_string()));
336 }
337
338 if let Some(name) = task_name {
339 query.push(("task_name", name.to_string()));
340 }
341
342 if let Some(ack) = acknowledged {
343 query.push(("acknowledged", ack.to_string()));
344 }
345
346 let resp = self
347 .request(
348 Method::GET,
349 &format!(
350 "/api/tasks/?{}",
351 serde_urlencoded::to_string(&query)
352 .map_err(|e| Error::Other(format!("Failed to serialize query: {e}")))?
353 ),
354 None,
355 )
356 .await?;
357
358 trace!("get_task_status response: {:?}", resp);
359
360 let body = resp
361 .text()
362 .await
363 .map_err(|e| Error::Other(format!("Failed to read response body: {e}")))?;
364
365 let tasks: Vec<Task> = match serde_json::from_str(&body) {
366 Ok(t) => t,
367 Err(e) => {
368 return Err(Error::InvalidJson(format!(
369 "Failed to parse response body: {e}"
370 )));
371 }
372 };
373
374 if tasks.is_empty() {
375 return Err(Error::NotFound);
376 }
377
378 Ok(tasks)
379 }
380
381 pub async fn upload_document(&self, file_path: &Path, filename: &str) -> Result<TaskId> {
385 let file_bytes = std::fs::read(file_path)
386 .map_err(|e| Error::Other(format!("Failed to read file: {e}")))?;
387
388 let form = multipart::Form::new().part(
389 "document",
390 multipart::Part::bytes(file_bytes).file_name(filename.to_string()),
391 );
392
393 let url = format!("{}/api/documents/post_document/", self.base_url);
394
395 let resp = self
396 .client
397 .post(&url)
398 .multipart(form)
399 .send()
400 .await
401 .map_err(|e| Error::Other(format!("Failed to send request: {e}")))?;
402
403 let status = resp.status();
404 if !resp.status().is_success() {
405 return Err(Error::Response {
406 status_code: status.as_u16(),
407 body: resp.text().await.unwrap_or_default(),
408 });
409 }
410
411 let task_id: String = resp
412 .json()
413 .await
414 .map_err(|e| Error::Other(format!("Failed to parse task ID: {e}")))?;
415 Ok(TaskId(task_id))
416 }
417
418 pub fn tags(&self) -> &HashMap<TagId, Tag> {
419 &self.tags
420 }
421
422 pub fn find_tag_by_name(&self, name: &str) -> Option<&Tag> {
423 self.tags.values().find(|tag| tag.name == name)
424 }
425
426 pub fn document_types(&self) -> &HashMap<DocumentTypeId, DocumentType> {
427 &self.document_types
428 }
429
430 pub fn correspondents(&self) -> &HashMap<CorrespondentId, Correspondent> {
431 &self.correspondents
432 }
433
434 pub fn custom_fields(&self) -> &HashMap<CustomFieldId, CustomField> {
435 &self.custom_fields
436 }
437
438 pub fn find_custom_field_by_name(&self, name: &str) -> Option<&CustomField> {
439 self.custom_fields.values().find(|field| field.name == name)
440 }
441}