Skip to main content

hypertune/
edge.rs

1use std::collections::HashMap;
2
3use anyhow::anyhow;
4use clap::ValueEnum;
5use serde::{Deserialize, Serialize};
6use serde_json::json;
7use serde_json::Value;
8
9use crate::constants;
10use crate::expression::Expression;
11use crate::split::CommitConfig;
12use crate::split::SplitMap;
13use crate::types::InitQuery;
14use crate::types::Message;
15
16#[derive(Deserialize, Serialize, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, ValueEnum)]
17#[serde(rename_all = "lowercase")]
18pub enum Language {
19    Python,
20    Rust,
21    Go,
22}
23
24#[derive(Deserialize)]
25pub struct CodegenFile {
26    pub name: String,
27    pub content: String,
28}
29#[derive(Deserialize)]
30pub struct CodegenResponse {
31    pub files: Vec<CodegenFile>,
32    pub messages: Vec<Message>,
33}
34
35#[derive(Deserialize, Debug, Eq, PartialEq)]
36#[serde(rename_all = "camelCase")]
37pub struct InitializationData {
38    pub reduced_expression: Expression,
39    pub hash: String,
40    pub commit_config: CommitConfig,
41    pub splits: SplitMap,
42    pub commit_id: u64,
43}
/// Parsed result of the `/hash` endpoint: a commit id and its hash string.
///
/// Derives the standard comparison/debug traits so callers can log it, clone
/// it, and compare it (e.g. in tests) like the other response types here.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct HashResponse {
    /// Commit id parsed from the part of the response before the first `_`.
    pub commit_id: u64,
    /// Everything after the first `_` in the response.
    pub hash: String,
}
48
49pub async fn codegen_request(
50    token: &str,
51    branch_name: Option<String>,
52    language: Language,
53    sdk_version: String,
54    query: Option<&str>,
55    include_token: bool,
56    include_fallback: bool,
57    base_url: &str,
58) -> Result<CodegenResponse, reqwest::Error> {
59    let body = json!({
60        "query": query,
61        "includeToken": include_token,
62        "includeFallback": include_fallback,
63        "sdkType": language,
64        "sdkVersion": sdk_version,
65        "language": language,
66    })
67    .to_string();
68    let mut params = HashMap::new();
69    params.insert("token", token);
70    params.insert("body", &body);
71    if let Some(ref branch_name) = branch_name {
72        params.insert("branch", branch_name.as_str());
73    }
74
75    let client = reqwest::Client::new();
76    client
77        .get(format!("{}/codegen", base_url))
78        .query(&params)
79        .send()
80        .await?
81        .error_for_status()?
82        .json()
83        .await
84}
85
86pub async fn init_request(
87    token: &str,
88    branch_name: Option<String>,
89    query: &InitQuery,
90    variables: &Value,
91    language: Language,
92    base_url: &str,
93) -> Result<InitializationData, reqwest::Error> {
94    let body = &json!({
95        "query": query,
96        "variables": variables,
97        "sdkType": language,
98        "sdkVersion": constants::VERSION,
99    })
100    .to_string();
101    let mut params = HashMap::new();
102    params.insert("token", token);
103    params.insert("body", body);
104    if let Some(ref branch_name) = branch_name {
105        params.insert("branch", branch_name.as_str());
106    }
107
108    let client = reqwest::Client::new();
109    client
110        .get(format!("{}/init", base_url))
111        .query(&params)
112        .send()
113        .await?
114        .error_for_status()?
115        .json()
116        .await
117}
118
119pub async fn hash_request(
120    token: &str,
121    branch_name: Option<String>,
122    query: &InitQuery,
123    variables: &Value,
124    language: Language,
125    base_url: &str,
126) -> Result<HashResponse, anyhow::Error> {
127    let body = &json!({
128        "query": query,
129        "variables": variables,
130        "sdkType": language,
131        "sdkVersion": constants::VERSION,
132    })
133    .to_string();
134    let mut params = HashMap::new();
135    params.insert("token", token);
136    params.insert("body", body);
137    if let Some(ref branch_name) = branch_name {
138        params.insert("branch", branch_name.as_str());
139    }
140
141    let client = reqwest::Client::new();
142    let text = client
143        .get(format!("{}/hash", base_url))
144        .query(&params)
145        .send()
146        .await?
147        .error_for_status()?
148        .text()
149        .await?;
150
151    if let Some((raw_commit_id, hash)) = text.split_once('_') {
152        let commit_id = raw_commit_id.parse::<u64>()?;
153        Ok(HashResponse {
154            commit_id,
155            hash: hash.to_string(),
156        })
157    } else {
158        Err(anyhow!("Invalid hash response: {}", text))
159    }
160}
161
#[cfg(test)]
mod tests {
    use super::*;
    use crate::expression::BooleanExpression;
    use crate::expression::Logs;
    use crate::expression::ObjectExpression;
    use std::collections::HashMap;

    /// Builds `Logs` that record a single evaluation count of 1 for `key`,
    /// with no event or exposure lists.
    fn single_evaluation_logs(key: &str) -> Logs {
        Logs {
            evaluations: Some(HashMap::from([(key.to_string(), 1)])),
            event_list: None,
            exposure_list: None,
        }
    }

    /// A representative `/init` payload must deserialize into the expected
    /// `InitializationData`. Unmodelled JSON keys (`type`, `valueType`,
    /// `metadata`) are expected to be ignored by the deserializer.
    #[test]
    fn test_deserialization() {
        let json = r#"
        {
            "commitId": 3535,
            "hash": "7719744257597328",
            "reducedExpression": {
              "id": "ih422aM1N7w5u1pZHjiHj",
              "type": "ObjectExpression",
              "fields": {
                "root": {
                  "id": "gZDqtndJoiI8RQ7q4CuCo",
                  "logs": {
                    "evaluations": {
                      "gy-OTD4_b0zH2XOmaXctc": 1
                    }
                  },
                  "type": "ObjectExpression",
                  "fields": {
                    "simpleFlag": {
                      "id": "3Py1edtYsW_dEqdwEewak",
                      "logs": {
                        "evaluations": {
                          "WO2MDkDuq0TMU5_MvKSe6": 1
                        }
                      },
                      "type": "BooleanExpression",
                      "value": false,
                      "valueType": {
                        "type": "BooleanValueType"
                      }
                    }
                  }, 
                  "valueType": {
                    "type": "ObjectValueType",
                    "objectTypeName": "Root"
                  },
                  "objectTypeName": "Root"
                }
              },
              "metadata": {
                "permissions": {
                  "user": {},
                  "group": {
                    "team": {
                      "write": "allow"
                    }
                  }
                }
              },
              "valueType": {
                "type": "ObjectValueType",
                "objectTypeName": "Query"
              },
              "objectTypeName": "Query"
            },
            "splits": {},
            "commitConfig": {
              "splitConfig": {}
            }
          }
"#;

        let deserialized: InitializationData = serde_json::from_str(json).unwrap();

        // Build the expected tree bottom-up: leaf flag, then `root`, then `Query`.
        let simple_flag = Expression::Boolean(BooleanExpression {
            id: "3Py1edtYsW_dEqdwEewak".to_string(),
            value: false,
            is_transient: false,
            logs: Some(single_evaluation_logs("WO2MDkDuq0TMU5_MvKSe6")),
        });

        let root = Expression::Object(ObjectExpression {
            id: "gZDqtndJoiI8RQ7q4CuCo".to_string(),
            object_type_name: "Root".to_string(),
            is_transient: false,
            logs: Some(single_evaluation_logs("gy-OTD4_b0zH2XOmaXctc")),
            fields: HashMap::from([("simpleFlag".to_string(), simple_flag)]),
        });

        let query = Expression::Object(ObjectExpression {
            id: "ih422aM1N7w5u1pZHjiHj".to_string(),
            object_type_name: "Query".to_string(),
            logs: None,
            is_transient: false,
            fields: HashMap::from([("root".to_string(), root)]),
        });

        let expected = InitializationData {
            commit_id: 3535,
            hash: "7719744257597328".to_string(),
            commit_config: CommitConfig::new(),
            splits: SplitMap::new(),
            reduced_expression: query,
        };

        assert_eq!(deserialized, expected)
    }
}