darwin_v7/
team.rs

1#[allow(unused_imports)]
2use fake::{Dummy, Fake};
3
4use crate::annotation::AnnotationClass;
5use crate::expect_http_ok;
6use async_trait::async_trait;
7use serde::{Deserialize, Serialize};
8use serde_yaml::Value;
9use std::{fmt::Display, path::PathBuf};
10
11use crate::client::V7Methods;
12use crate::errors::DarwinV7Error;
13use crate::workflow::{WorkflowBuilder, WorkflowV2};
14
/// A V7 team as configured locally: its slug plus optional per-team settings.
///
/// `Debug` is implemented by hand (not derived) so that the API key is never
/// printed in logs, panic messages or `{:?}` output.
#[derive(Default, PartialEq, Eq, Clone)]
pub struct Team {
    /// Team slug; interpolated into `teams/{slug}/...` API endpoints.
    pub slug: String,
    /// Local directory for this team's datasets, if configured.
    pub datasets_dir: Option<PathBuf>,
    /// API key for this team, if configured. Redacted from `Debug` output.
    pub api_key: Option<String>,
    /// Numeric team id, when known.
    pub team_id: Option<u32>,
}

impl std::fmt::Debug for Team {
    /// Same layout as a derived `Debug`, except a present API key is shown as
    /// `Some("<redacted>")` instead of its value.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Team")
            .field("slug", &self.slug)
            .field("datasets_dir", &self.datasets_dir)
            .field("api_key", &self.api_key.as_ref().map(|_| "<redacted>"))
            .field("team_id", &self.team_id)
            .finish()
    }
}
22
23#[derive(Debug, Default, Serialize, Deserialize, Dummy, PartialEq, Eq, Clone)]
24pub struct TeamMember {
25    pub id: Option<u32>,
26    pub email: Option<String>,
27    pub first_name: Option<String>,
28    pub last_name: Option<String>,
29    pub role: Option<String>,
30    pub team_id: Option<u32>,
31    pub user_id: Option<u32>,
32}
33
34impl Display for TeamMember {
35    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
36        write!(
37            f,
38            "{{id-{}}}{} {} ({})",
39            self.user_id.unwrap_or_default(),
40            self.first_name.as_ref().unwrap_or(&String::new()),
41            self.last_name.as_ref().unwrap_or(&String::new()),
42            self.email.as_ref().unwrap_or(&String::new())
43        )
44    }
45}
46
47#[derive(Debug, Default, Clone, Serialize, Deserialize, Dummy, PartialEq, Eq)]
48pub struct TypeCount {
49    pub count: Option<u32>,
50    pub id: Option<u32>,
51    pub name: Option<String>,
52}
53
54#[derive(Debug, Default, Clone, Serialize, Deserialize, Dummy)]
55pub struct TeamAnnotationClasses {
56    pub annotation_classes: Vec<Option<AnnotationClass>>,
57    pub type_counts: Vec<Option<TypeCount>>,
58}
59
60impl TryFrom<(&Value, &Value)> for Team {
61    type Error = DarwinV7Error;
62
63    fn try_from((team_slug, team): (&Value, &Value)) -> Result<Self, Self::Error> {
64        let slug = team_slug
65            .as_str()
66            .ok_or(DarwinV7Error::InvalidConfigError(
67                "Invalid team slug".to_string(),
68            ))?
69            .to_string();
70        let api_key: Option<String> = match team.get("api_key") {
71            Some(key) => Some(
72                key.as_str()
73                    .ok_or(DarwinV7Error::InvalidConfigError(
74                        "Invalid api-key".to_string(),
75                    ))?
76                    .to_string(),
77            ),
78            None => None,
79        };
80        let datasets_dir: Option<PathBuf> = match team.get("datasets_dir") {
81            Some(key) => Some(PathBuf::from(
82                key.as_str()
83                    .ok_or(DarwinV7Error::InvalidConfigError(
84                        "Invalid datasets_dir".to_string(),
85                    ))?
86                    .to_string(),
87            )),
88            None => None,
89        };
90
91        Ok(Self {
92            slug,
93            datasets_dir,
94            api_key,
95            team_id: None,
96        })
97    }
98}
99
100#[async_trait]
101pub trait TeamDescribeMethods<C>
102where
103    C: V7Methods,
104{
105    async fn list_memberships(client: &C) -> Result<Vec<TeamMember>, DarwinV7Error>;
106    async fn list_annotation_classes(
107        &self,
108        client: &C,
109    ) -> Result<TeamAnnotationClasses, DarwinV7Error>;
110}
111
112#[async_trait]
113pub trait TeamDataMethods<C>
114where
115    C: V7Methods,
116{
117    async fn create_annotation_class(
118        &self,
119        client: &C,
120        class: &AnnotationClass,
121    ) -> Result<AnnotationClass, DarwinV7Error>;
122    // async fn delete_annotation_classes(
123    //     &self,
124    //     client: &C,
125    //     classes: &[AnnotationClass],
126    // ) -> Result<()>;
127}
128
129#[async_trait]
130pub trait TeamWorkflowMethods<C>
131where
132    C: V7Methods,
133{
134    async fn create_workflow(
135        &self,
136        client: &C,
137        workflow: &WorkflowBuilder,
138    ) -> Result<WorkflowV2, DarwinV7Error>;
139}
140
141#[async_trait]
142impl<C> TeamWorkflowMethods<C> for Team
143where
144    C: V7Methods + std::marker::Sync,
145{
146    async fn create_workflow(
147        &self,
148        client: &C,
149        workflow: &WorkflowBuilder,
150    ) -> Result<WorkflowV2, DarwinV7Error>
151    where
152        C: V7Methods,
153    {
154        let response = client
155            .post(&format!("v2/teams/{}/workflows", self.slug), workflow)
156            .await?;
157        // 201 is correct operation for this endpoint
158        if response.status() == 201 {
159            Ok(response.json().await?)
160        } else {
161            Err(DarwinV7Error::HTTPError(
162                response.status(),
163                response.text().await?,
164            ))
165        }
166    }
167}
168
169impl Team {
170    pub fn new(
171        slug: String,
172        datasets_dir: Option<PathBuf>,
173        api_key: Option<String>,
174        team_id: Option<u32>,
175    ) -> Self {
176        Self {
177            slug,
178            datasets_dir,
179            api_key,
180            team_id,
181        }
182    }
183}
184
185#[async_trait]
186impl<C> TeamDescribeMethods<C> for Team
187where
188    C: V7Methods + std::marker::Sync,
189{
190    // This uses the authentication token
191    async fn list_memberships(client: &C) -> Result<Vec<TeamMember>, DarwinV7Error> {
192        let response = client.get("memberships").await?;
193
194        expect_http_ok!(response, Vec<TeamMember>)
195    }
196
197    // Relies upon the team id / slug
198    async fn list_annotation_classes(
199        &self,
200        client: &C,
201    ) -> Result<TeamAnnotationClasses, DarwinV7Error> {
202        // TODO: add query params
203        let endpoint = format!("teams/{}/annotation_classes", self.slug);
204        let response = client.get(&endpoint).await?;
205
206        expect_http_ok!(response, TeamAnnotationClasses)
207    }
208}
209
210#[derive(Debug, Default, Serialize, Deserialize, Dummy, PartialEq, Eq, Clone)]
211struct DeleteClassesPayload {
212    pub annotation_class_ids: Vec<u32>,
213    pub annotations_to_delete_count: u32,
214}
215
216#[async_trait]
217impl<C> TeamDataMethods<C> for Team
218where
219    C: V7Methods + std::marker::Sync,
220{
221    async fn create_annotation_class(
222        &self,
223        client: &C,
224        class: &AnnotationClass,
225    ) -> Result<AnnotationClass, DarwinV7Error>
226    where
227        C: V7Methods,
228    {
229        let endpoint = format!("teams/{}/annotation_classes", self.slug);
230        let response = client.post(&endpoint, class).await?;
231
232        expect_http_ok!(response, AnnotationClass)
233    }
234    // async fn delete_annotation_classes(&self, client: &C, classes: &[AnnotationClass]) -> Result<()>
235    // where
236    //     C: V7Methods,
237    // {
238    //     let endpoint = format!(
239    //         "teams/{}/delete_classes",
240    //         self.team_id.context("Missing team id")?
241    //     );
242
243    //     let mut payload = DeleteClassesPayload::default();
244    //     for class in classes.iter() {
245    //         payload.annotation_class_ids.push(class.id.context(format!(
246    //             "Class {} missing id",
247    //             class.name.clone().unwrap_or(String::new())
248    //         ))?);
249    //     }
250
251    //     let response = client.delete(&endpoint, &payload).await?;
252
253    //     let status = response.status();
254    //     if status != 204 {
255    //         bail!("Unable to delete classes with status code {}", status);
256    //     }
257
258    //     Ok(())
259    // }
260}
261
262#[derive(Debug, Clone, Serialize, Deserialize, Dummy, PartialEq, Eq)]
263pub struct MetadataSkeleton {
264    #[serde(rename = "_type")]
265    pub skeleton_type: String,
266}
267
268pub mod helpers {
269    use crate::{client::V7Methods, errors::DarwinV7Error};
270
271    use super::{Team, TeamDescribeMethods, TeamMember};
272
273    pub async fn find_team_members<C, F>(
274        client: &C,
275        func: F,
276    ) -> Result<Vec<TeamMember>, DarwinV7Error>
277    where
278        C: V7Methods + std::marker::Sync,
279        F: Fn(&TeamMember) -> bool,
280    {
281        Ok(Team::list_memberships(client)
282            .await?
283            .iter()
284            .filter(|x| func(x))
285            .cloned()
286            .collect::<Vec<TeamMember>>())
287    }
288
289    pub async fn find_team_members_by_email<C>(
290        client: &C,
291        email: &str,
292    ) -> Result<Vec<TeamMember>, DarwinV7Error>
293    where
294        C: V7Methods + std::marker::Sync,
295    {
296        find_team_members(client, |x| -> bool {
297            x.email.as_ref().unwrap_or(&String::new()).contains(email)
298        })
299        .await
300    }
301}
302
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `yaml` as a mapping and converts its first entry into a [`Team`].
    fn parse_first_team(yaml: &str) -> Team {
        let raw: Value = serde_yaml::from_str(yaml).unwrap();
        let (slug, settings) = raw.as_mapping().unwrap().iter().next().unwrap();
        Team::try_from((slug, settings)).unwrap()
    }

    #[test]
    fn test_from_str_all_fields() {
        let team_a: &'static str = "team-a:
        api_key: 1ed99664-726e-4400-bc5d-3132b22ce60c
        datasets_dir: /home/user/.v7/team-a
        ";

        let team = parse_first_team(team_a);
        assert_eq!(team.slug, "team-a");
        assert_eq!(
            team.api_key.as_deref(),
            Some("1ed99664-726e-4400-bc5d-3132b22ce60c")
        );
        assert_eq!(
            team.datasets_dir,
            Some(PathBuf::from("/home/user/.v7/team-a"))
        );
    }

    #[test]
    fn test_from_str_slug_only() {
        let team = parse_first_team("team-b:\n");
        assert_eq!(team.slug, "team-b");
        assert!(team.api_key.is_none());
        assert!(team.datasets_dir.is_none());
    }
}