1use std::collections::HashMap;
2
3use anyhow::anyhow;
4use clap::ValueEnum;
5use serde::{Deserialize, Serialize};
6use serde_json::json;
7use serde_json::Value;
8
9use crate::constants;
10use crate::expression::Expression;
11use crate::split::CommitConfig;
12use crate::split::SplitMap;
13use crate::types::InitQuery;
14use crate::types::Message;
15
16#[derive(Deserialize, Serialize, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, ValueEnum)]
17#[serde(rename_all = "lowercase")]
18pub enum Language {
19 Python,
20 Rust,
21 Go,
22}
23
24#[derive(Deserialize)]
25pub struct CodegenFile {
26 pub name: String,
27 pub content: String,
28}
29#[derive(Deserialize)]
30pub struct CodegenResponse {
31 pub files: Vec<CodegenFile>,
32 pub messages: Vec<Message>,
33}
34
35#[derive(Deserialize, Debug, Eq, PartialEq)]
36#[serde(rename_all = "camelCase")]
37pub struct InitializationData {
38 pub reduced_expression: Expression,
39 pub hash: String,
40 pub commit_config: CommitConfig,
41 pub splits: SplitMap,
42 pub commit_id: u64,
43}
/// Commit id and hash parsed from the plain-text `/hash` endpoint response.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct HashResponse {
    /// Numeric commit id (the part before the first `_` in the response).
    pub commit_id: u64,
    /// Hash string (everything after the first `_` in the response).
    pub hash: String,
}
48
49pub async fn codegen_request(
50 token: &str,
51 branch_name: Option<String>,
52 language: Language,
53 sdk_version: String,
54 query: Option<&str>,
55 include_token: bool,
56 include_fallback: bool,
57 base_url: &str,
58) -> Result<CodegenResponse, reqwest::Error> {
59 let body = json!({
60 "query": query,
61 "includeToken": include_token,
62 "includeFallback": include_fallback,
63 "sdkType": language,
64 "sdkVersion": sdk_version,
65 "language": language,
66 })
67 .to_string();
68 let mut params = HashMap::new();
69 params.insert("token", token);
70 params.insert("body", &body);
71 if let Some(ref branch_name) = branch_name {
72 params.insert("branch", branch_name.as_str());
73 }
74
75 let client = reqwest::Client::new();
76 client
77 .get(format!("{}/codegen", base_url))
78 .query(¶ms)
79 .send()
80 .await?
81 .error_for_status()?
82 .json()
83 .await
84}
85
86pub async fn init_request(
87 token: &str,
88 branch_name: Option<String>,
89 query: &InitQuery,
90 variables: &Value,
91 language: Language,
92 base_url: &str,
93) -> Result<InitializationData, reqwest::Error> {
94 let body = &json!({
95 "query": query,
96 "variables": variables,
97 "sdkType": language,
98 "sdkVersion": constants::VERSION,
99 })
100 .to_string();
101 let mut params = HashMap::new();
102 params.insert("token", token);
103 params.insert("body", body);
104 if let Some(ref branch_name) = branch_name {
105 params.insert("branch", branch_name.as_str());
106 }
107
108 let client = reqwest::Client::new();
109 client
110 .get(format!("{}/init", base_url))
111 .query(¶ms)
112 .send()
113 .await?
114 .error_for_status()?
115 .json()
116 .await
117}
118
119pub async fn hash_request(
120 token: &str,
121 branch_name: Option<String>,
122 query: &InitQuery,
123 variables: &Value,
124 language: Language,
125 base_url: &str,
126) -> Result<HashResponse, anyhow::Error> {
127 let body = &json!({
128 "query": query,
129 "variables": variables,
130 "sdkType": language,
131 "sdkVersion": constants::VERSION,
132 })
133 .to_string();
134 let mut params = HashMap::new();
135 params.insert("token", token);
136 params.insert("body", body);
137 if let Some(ref branch_name) = branch_name {
138 params.insert("branch", branch_name.as_str());
139 }
140
141 let client = reqwest::Client::new();
142 let text = client
143 .get(format!("{}/hash", base_url))
144 .query(¶ms)
145 .send()
146 .await?
147 .error_for_status()?
148 .text()
149 .await?;
150
151 if let Some((raw_commit_id, hash)) = text.split_once('_') {
152 let commit_id = raw_commit_id.parse::<u64>()?;
153 Ok(HashResponse {
154 commit_id,
155 hash: hash.to_string(),
156 })
157 } else {
158 Err(anyhow!("Invalid hash response: {}", text))
159 }
160}
161
#[cfg(test)]
mod tests {
    use super::*;
    use crate::expression::BooleanExpression;
    use crate::expression::Logs;
    use crate::expression::ObjectExpression;
    use std::collections::HashMap;

    /// Builds a `Logs` value that records exactly one evaluation of `evaluation_id`.
    fn single_evaluation_logs(evaluation_id: &str) -> Logs {
        Logs {
            evaluations: Some(HashMap::from([(evaluation_id.to_string(), 1)])),
            event_list: None,
            exposure_list: None,
        }
    }

    /// A representative `/init` payload must deserialize into the matching
    /// `InitializationData`, including nested object/boolean expressions and logs.
    #[test]
    fn test_deserialization() {
        let json = r#"
        {
            "commitId": 3535,
            "hash": "7719744257597328",
            "reducedExpression": {
                "id": "ih422aM1N7w5u1pZHjiHj",
                "type": "ObjectExpression",
                "fields": {
                    "root": {
                        "id": "gZDqtndJoiI8RQ7q4CuCo",
                        "logs": {
                            "evaluations": {
                                "gy-OTD4_b0zH2XOmaXctc": 1
                            }
                        },
                        "type": "ObjectExpression",
                        "fields": {
                            "simpleFlag": {
                                "id": "3Py1edtYsW_dEqdwEewak",
                                "logs": {
                                    "evaluations": {
                                        "WO2MDkDuq0TMU5_MvKSe6": 1
                                    }
                                },
                                "type": "BooleanExpression",
                                "value": false,
                                "valueType": {
                                    "type": "BooleanValueType"
                                }
                            }
                        },
                        "valueType": {
                            "type": "ObjectValueType",
                            "objectTypeName": "Root"
                        },
                        "objectTypeName": "Root"
                    }
                },
                "metadata": {
                    "permissions": {
                        "user": {},
                        "group": {
                            "team": {
                                "write": "allow"
                            }
                        }
                    }
                },
                "valueType": {
                    "type": "ObjectValueType",
                    "objectTypeName": "Query"
                },
                "objectTypeName": "Query"
            },
            "splits": {},
            "commitConfig": {
                "splitConfig": {}
            }
        }
"#;

        let deserialized: InitializationData = serde_json::from_str(json).unwrap();

        // Build the expected tree bottom-up: leaf flag, then `root`, then `Query`.
        let simple_flag = Expression::Boolean(BooleanExpression {
            id: "3Py1edtYsW_dEqdwEewak".to_string(),
            value: false,
            is_transient: false,
            logs: Some(single_evaluation_logs("WO2MDkDuq0TMU5_MvKSe6")),
        });

        let root = Expression::Object(ObjectExpression {
            id: "gZDqtndJoiI8RQ7q4CuCo".to_string(),
            object_type_name: "Root".to_string(),
            is_transient: false,
            logs: Some(single_evaluation_logs("gy-OTD4_b0zH2XOmaXctc")),
            fields: HashMap::from([("simpleFlag".to_string(), simple_flag)]),
        });

        let query = Expression::Object(ObjectExpression {
            id: "ih422aM1N7w5u1pZHjiHj".to_string(),
            object_type_name: "Query".to_string(),
            logs: None,
            is_transient: false,
            fields: HashMap::from([("root".to_string(), root)]),
        });

        let expected = InitializationData {
            commit_id: 3535,
            hash: "7719744257597328".to_string(),
            commit_config: CommitConfig::new(),
            splits: SplitMap::new(),
            reduced_expression: query,
        };

        assert_eq!(deserialized, expected);
    }
}