Skip to main content

edgehog_device_runtime_containers/store/
network.rs

1// This file is part of Edgehog.
2//
3// Copyright 2025 SECO Mind Srl
4//
5// Licensed under the Apache License, Version 2.0 (the "License");
6// you may not use this file except in compliance with the License.
7// You may obtain a copy of the License at
8//
9//    http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing, software
12// distributed under the License is distributed on an "AS IS" BASIS,
13// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14// See the License for the specific language governing permissions and
15// limitations under the License.
16//
17// SPDX-License-Identifier: Apache-2.0
18
19use diesel::query_dsl::methods::{FilterDsl, SelectDsl};
20use diesel::{
21    delete, insert_or_ignore_into, update, ExpressionMethods, OptionalExtension, RunQueryDsl,
22};
23use edgehog_store::conversions::SqlUuid;
24use edgehog_store::db::HandleError;
25use edgehog_store::models::containers::container::ContainerMissingNetwork;
26use edgehog_store::models::containers::network::{Network, NetworkDriverOpts, NetworkStatus};
27use edgehog_store::models::QueryModel;
28use edgehog_store::schema::containers::{container_networks, network_driver_opts, networks};
29use itertools::Itertools;
30use tracing::instrument;
31use uuid::Uuid;
32
33use crate::resource::network::NetworkResource;
34use crate::{docker::network::Network as ContainerNetwork, requests::network::CreateNetwork};
35
36use super::{split_key_value, Result, StateStore, StoreError};
37
38impl StateStore {
39    /// Stores the network received from the CreateRequest
40    #[instrument(skip_all, fields(%create_network.id))]
41    pub(crate) async fn create_network(&self, create_network: CreateNetwork) -> Result<()> {
42        let opts = Vec::<NetworkDriverOpts>::try_from(&create_network)?;
43        let network = Network::from(create_network);
44
45        self.handle
46            .for_write(move |writer| {
47                insert_or_ignore_into(networks::table)
48                    .values(&network)
49                    .execute(writer)?;
50
51                insert_or_ignore_into(network_driver_opts::table)
52                    .values(opts)
53                    .execute(writer)?;
54
55                insert_or_ignore_into(container_networks::table)
56                    .values(ContainerMissingNetwork::find_by_network(&network.id))
57                    .execute(writer)?;
58
59                delete(ContainerMissingNetwork::find_by_network(&network.id)).execute(writer)?;
60
61                Ok(())
62            })
63            .await?;
64
65        Ok(())
66    }
67
68    /// Updates the state of a network
69    #[instrument(skip(self))]
70    pub(crate) async fn update_network_status(
71        &self,
72        network_id: Uuid,
73        status: NetworkStatus,
74    ) -> Result<()> {
75        self.handle
76            .for_write(move |writer| {
77                let updated = update(Network::find_id(&SqlUuid::new(network_id)))
78                    .set(networks::status.eq(status))
79                    .execute(writer)?;
80
81                HandleError::check_modified(updated, 1)?;
82
83                Ok(())
84            })
85            .await?;
86
87        Ok(())
88    }
89
90    /// Updates the local id of a [`Network`]
91    #[instrument(skip(self))]
92    pub(crate) async fn update_network_local_id(
93        &self,
94        network_id: Uuid,
95        local_id: Option<String>,
96    ) -> Result<()> {
97        self.handle
98            .for_write(move |writer| {
99                let updated = update(Network::find_id(&SqlUuid::new(network_id)))
100                    .set(networks::local_id.eq(local_id))
101                    .execute(writer)?;
102
103                HandleError::check_modified(updated, 1)?;
104
105                Ok(())
106            })
107            .await?;
108
109        Ok(())
110    }
111
112    /// Deletes a [`Network`]
113    #[instrument(skip(self))]
114    pub(crate) async fn delete_network(&self, network_id: Uuid) -> Result<()> {
115        self.handle
116            .for_write(move |writer| {
117                let updated =
118                    delete(Network::find_id(&SqlUuid::new(network_id))).execute(writer)?;
119
120                HandleError::check_modified(updated, 1)?;
121
122                Ok(())
123            })
124            .await?;
125
126        Ok(())
127    }
128
129    #[instrument(skip(self))]
130    pub(crate) async fn load_networks_to_publish(&self) -> Result<Vec<SqlUuid>> {
131        let networks = self
132            .handle
133            .for_read(move |reader| {
134                let networks = networks::table
135                    .select(networks::id)
136                    .filter(networks::status.eq(NetworkStatus::Received))
137                    .load::<SqlUuid>(reader)?;
138
139                Ok(networks)
140            })
141            .await?;
142
143        Ok(networks)
144    }
145
146    #[instrument(skip(self))]
147    pub(crate) async fn find_network(&self, network_id: Uuid) -> Result<Option<NetworkResource>> {
148        let network = self
149            .handle
150            .for_read(move |reader| {
151                let id = SqlUuid::new(network_id);
152                let Some(network) = Network::find_id(&id).first::<Network>(reader).optional()?
153                else {
154                    return Ok(None);
155                };
156
157                let driver_opts = network_driver_opts::table
158                    .filter(network_driver_opts::network_id.eq(id))
159                    .load::<NetworkDriverOpts>(reader)?
160                    .into_iter()
161                    .map(|opt| (opt.name, opt.value))
162                    .collect();
163
164                Ok(Some(NetworkResource::new(ContainerNetwork::new(
165                    network.local_id,
166                    *network.id,
167                    network.driver,
168                    network.internal,
169                    network.enable_ipv6,
170                    driver_opts,
171                ))))
172            })
173            .await?;
174
175        Ok(network)
176    }
177
178    /// Finds the unique id of the network with the given local id
179    ///
180    /// Returns the id of the network and the reference.
181    #[instrument(skip(self))]
182    pub(crate) async fn find_network_by_local_id(&self, local_id: String) -> Result<Option<Uuid>> {
183        let id = self
184            .handle
185            .for_read(|reader| {
186                networks::table
187                    .filter(networks::local_id.eq(local_id))
188                    .select(networks::id)
189                    .first::<SqlUuid>(reader)
190                    .map(|id| *id)
191                    .optional()
192                    .map_err(HandleError::Query)
193            })
194            .await?;
195
196        Ok(id)
197    }
198}
199
200impl From<CreateNetwork> for Network {
201    fn from(
202        CreateNetwork {
203            id,
204            deployment_id: _,
205            driver,
206            internal,
207            enable_ipv6,
208            options: _,
209        }: CreateNetwork,
210    ) -> Self {
211        Self {
212            id: SqlUuid::new(id),
213            local_id: None,
214            status: NetworkStatus::default(),
215            driver: driver.to_string(),
216            internal,
217            enable_ipv6,
218        }
219    }
220}
221
222// Takes the options
223impl TryFrom<&CreateNetwork> for Vec<NetworkDriverOpts> {
224    type Error = StoreError;
225
226    fn try_from(value: &CreateNetwork) -> std::result::Result<Self, Self::Error> {
227        let network_id = SqlUuid::new(value.id);
228
229        value
230            .options
231            .iter()
232            .map(|s| {
233                split_key_value(s)
234                    .map(|(name, value)| NetworkDriverOpts {
235                        network_id,
236                        name: name.to_string(),
237                        value: value.unwrap_or_default().to_string(),
238                    })
239                    .ok_or(StoreError::ParseKeyValue {
240                        ctx: "network driver options",
241                        value: s.to_string(),
242                    })
243            })
244            .try_collect()
245    }
246}
247
#[cfg(test)]
mod tests {
    use crate::requests::ReqUuid;

    use super::*;

    use edgehog_store::db;
    use pretty_assertions::assert_eq;
    use tempfile::TempDir;

    /// Opens a fresh store backed by `state.db` inside a new temporary directory.
    ///
    /// The directory guard is returned as well so the caller keeps it alive for the whole test.
    async fn open_store(prefix: &str) -> (TempDir, StateStore) {
        let dir = TempDir::with_prefix(prefix).unwrap();
        let path = dir.path().join("state.db");
        let handle = db::Handle::open(path.to_str().unwrap()).await.unwrap();

        (dir, StateStore::new(handle))
    }

    /// Request for an internal `bridge` network with the single option `isolate=true`.
    fn bridge_request(network_id: Uuid) -> CreateNetwork {
        CreateNetwork {
            id: ReqUuid(network_id),
            deployment_id: ReqUuid(Uuid::new_v4()),
            driver: "bridge".to_string(),
            internal: true,
            enable_ipv6: false,
            options: vec!["isolate=true".to_string()],
        }
    }

    /// Row expected in the database for a network created from [`bridge_request`].
    fn bridge_network(network_id: Uuid, status: NetworkStatus) -> Network {
        Network {
            id: SqlUuid::new(network_id),
            local_id: None,
            status,
            driver: "bridge".to_string(),
            internal: true,
            enable_ipv6: false,
        }
    }

    /// Reads a network row directly from the database, bypassing the store API.
    async fn find_network(store: &StateStore, id: Uuid) -> Option<Network> {
        store
            .handle
            .for_read(move |reader| {
                Network::find_id(&SqlUuid::new(id))
                    .first::<Network>(reader)
                    .optional()
                    .map_err(HandleError::Query)
            })
            .await
            .unwrap()
    }

    impl StateStore {
        /// Loads every stored driver option of the given network.
        pub(crate) async fn network_opts(
            &self,
            network_id: Uuid,
        ) -> Result<Vec<NetworkDriverOpts>> {
            let opts = self
                .handle
                .for_read(move |reader| {
                    let rows = network_driver_opts::table
                        .filter(network_driver_opts::network_id.eq(SqlUuid::new(network_id)))
                        .load::<NetworkDriverOpts>(reader)?;

                    Ok(rows)
                })
                .await?;

            Ok(opts)
        }
    }

    #[tokio::test]
    async fn should_store() {
        let (_tmp, store) = open_store("store_network").await;
        let network_id = Uuid::new_v4();

        store.create_network(bridge_request(network_id)).await.unwrap();

        let stored = find_network(&store, network_id).await.unwrap();
        assert_eq!(stored, bridge_network(network_id, NetworkStatus::Received));

        let opts = store.network_opts(network_id).await.unwrap();
        let expected_opts = vec![NetworkDriverOpts {
            network_id: SqlUuid::new(network_id),
            name: "isolate".to_string(),
            value: "true".to_string(),
        }];
        assert_eq!(opts, expected_opts);
    }

    #[tokio::test]
    async fn should_update() {
        let (_tmp, store) = open_store("update_network").await;
        let network_id = Uuid::new_v4();

        store.create_network(bridge_request(network_id)).await.unwrap();
        store
            .update_network_status(network_id, NetworkStatus::Published)
            .await
            .unwrap();

        let stored = find_network(&store, network_id).await.unwrap();
        assert_eq!(stored, bridge_network(network_id, NetworkStatus::Published));
    }

    #[tokio::test]
    async fn find_network_by_local_id() {
        let (_tmp, store) = open_store("find_network_by_local_id").await;
        let network_id = Uuid::new_v4();
        let local_id = Uuid::new_v4().to_string();

        store.create_network(bridge_request(network_id)).await.unwrap();
        store
            .update_network_local_id(network_id, Some(local_id.clone()))
            .await
            .unwrap();

        let found = store.find_network_by_local_id(local_id).await.unwrap();

        assert_eq!(found, Some(network_id));
    }
}