xai_rust/api/
documents.rs1use serde::{Deserialize, Serialize};
4
5use crate::client::XaiClient;
6use crate::{Error, Result};
7
8#[derive(Debug, Clone)]
10pub struct DocumentsApi {
11 client: XaiClient,
12}
13
14impl DocumentsApi {
15 pub(crate) fn new(client: XaiClient) -> Self {
16 Self { client }
17 }
18
19 pub async fn search(&self, request: SearchRequest) -> Result<SearchResponse> {
21 let url = format!("{}/documents/search", self.client.base_url());
22
23 let response = self
24 .client
25 .send(self.client.http().post(&url).json(&request))
26 .await?;
27
28 if !response.status().is_success() {
29 return Err(Error::from_response(response).await);
30 }
31
32 Ok(response.json().await?)
33 }
34}
35
36#[derive(Debug, Clone, Serialize)]
38pub struct SearchRequest {
39 pub query: String,
41 #[serde(skip_serializing_if = "Option::is_none")]
43 pub limit: Option<u32>,
44 #[serde(skip_serializing_if = "Option::is_none")]
46 pub score_threshold: Option<f32>,
47}
48
49impl SearchRequest {
50 pub fn new(query: impl Into<String>) -> Self {
52 Self {
53 query: query.into(),
54 limit: None,
55 score_threshold: None,
56 }
57 }
58
59 pub fn limit(mut self, value: u32) -> Self {
61 self.limit = Some(value);
62 self
63 }
64
65 pub fn score_threshold(mut self, value: f32) -> Self {
67 self.score_threshold = Some(value);
68 self
69 }
70}
71
72#[derive(Debug, Clone, Deserialize)]
74pub struct SearchResponse {
75 pub results: Vec<SearchResult>,
77}
78
79#[derive(Debug, Clone, Deserialize)]
81pub struct SearchResult {
82 pub document: SearchDocument,
84 pub score: f32,
86}
87
88#[derive(Debug, Clone, Deserialize)]
90pub struct SearchDocument {
91 pub id: String,
93 pub content: String,
95}
96
#[cfg(test)]
mod tests {
    use super::*;
    use wiremock::matchers::{body_json, method, path};
    use wiremock::{Mock, MockServer, ResponseTemplate};

    /// Builds a client whose base URL points at `server`.
    fn client_for(server: &MockServer) -> crate::client::XaiClient {
        crate::client::XaiClient::builder()
            .api_key("test-key")
            .base_url(server.uri())
            .build()
            .unwrap()
    }

    #[tokio::test]
    async fn search_posts_to_documents_search_endpoint() {
        let server = MockServer::start().await;

        // The exact-body matcher also proves that unset optional fields are
        // omitted rather than serialized as `null`.
        Mock::given(method("POST"))
            .and(path("/documents/search"))
            .and(body_json(serde_json::json!({ "query": "hello" })))
            .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
                "results": [{
                    "document": {
                        "id": "doc-1",
                        "content": "hello"
                    },
                    "score": 0.97
                }]
            })))
            .mount(&server)
            .await;

        let response = client_for(&server)
            .documents()
            .search(SearchRequest::new("hello"))
            .await
            .unwrap();

        assert_eq!(response.results.len(), 1);
        assert_eq!(response.results[0].document.id, "doc-1");
        assert_eq!(response.results[0].document.content, "hello");
        assert!((response.results[0].score - 0.97).abs() < 0.0001);
    }

    #[tokio::test]
    async fn search_serializes_optional_fields_when_set() {
        let server = MockServer::start().await;

        Mock::given(method("POST"))
            .and(path("/documents/search"))
            .and(body_json(serde_json::json!({
                "query": "rust",
                "limit": 5,
                "score_threshold": 0.5
            })))
            .respond_with(
                ResponseTemplate::new(200).set_body_json(serde_json::json!({ "results": [] })),
            )
            .mount(&server)
            .await;

        let response = client_for(&server)
            .documents()
            .search(SearchRequest::new("rust").limit(5).score_threshold(0.5))
            .await
            .unwrap();

        assert!(response.results.is_empty());
    }

    #[tokio::test]
    async fn search_returns_error_on_non_success_status() {
        let server = MockServer::start().await;

        Mock::given(method("POST"))
            .and(path("/documents/search"))
            .respond_with(ResponseTemplate::new(400).set_body_json(serde_json::json!({
                "error": "bad request"
            })))
            .mount(&server)
            .await;

        let result = client_for(&server)
            .documents()
            .search(SearchRequest::new("hello"))
            .await;

        assert!(result.is_err());
    }
}