Skip to main content

couchbase_core/mgmtx/
mgmt.rs

1/*
2 *
3 *  * Copyright (c) 2025 Couchbase, Inc.
4 *  *
5 *  * Licensed under the Apache License, Version 2.0 (the "License");
6 *  * you may not use this file except in compliance with the License.
7 *  * You may obtain a copy of the License at
8 *  *
9 *  *    http://www.apache.org/licenses/LICENSE-2.0
10 *  *
11 *  * Unless required by applicable law or agreed to in writing, software
12 *  * distributed under the License is distributed on an "AS IS" BASIS,
13 *  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 *  * See the License for the specific language governing permissions and
15 *  * limitations under the License.
16 *
17 */
18
19use crate::cbconfig::{FullBucketConfig, FullClusterConfig, TerseConfig};
20use crate::httpx::client::Client;
21use crate::httpx::request::{Auth, OnBehalfOfInfo, Request};
22use crate::httpx::response::Response;
23use crate::mgmtx::error;
24use crate::mgmtx::mgmt_query::IndexStatus;
25use crate::mgmtx::options::{
26    GetAutoFailoverSettingsOptions, GetBucketStatsOptions, GetFullBucketConfigOptions,
27    GetFullClusterConfigOptions, GetTerseBucketConfigOptions, GetTerseClusterConfigOptions,
28    IndexStatusOptions, LoadSampleBucketOptions,
29};
30use crate::tracingcomponent::TracingComponent;
31use bytes::Bytes;
32use http::Method;
33use serde::de::DeserializeOwned;
34use serde::{Deserialize, Deserializer};
35use serde_json::value::RawValue;
36use std::collections::HashMap;
37use std::sync::Arc;
38use std::time::Duration;
39
40lazy_static! {
41    static ref FIELD_NAME_MAP: HashMap<String, String> = {
42        HashMap::from([
43            (
44                "durability_min_level".to_string(),
45                "DurabilityMinLevel".to_string(),
46            ),
47            ("ramquota".to_string(), "RamQuotaMB".to_string()),
48            ("replicanumber".to_string(), "ReplicaNumber".to_string()),
49            ("maxttl".to_string(), "MaxTTL".to_string()),
50            ("history".to_string(), "HistoryEnabled".to_string()),
51            ("numvbuckets".to_string(), "numVBuckets".to_string()),
52        ])
53    };
54}
55
/// Issues HTTP requests against a cluster management endpoint (e.g. `pools/default`,
/// `settings/autoFailover`) and decodes the responses.
///
/// Every request is sent through `http_client` to `"{endpoint}/{path}"` with `user_agent`
/// and, unless a per-request on-behalf-of identity is supplied, a clone of `auth`.
#[derive(Debug)]
pub struct Management<C: Client> {
    /// Client used to send every request built by this type.
    pub http_client: Arc<C>,
    /// User-agent string attached to each request.
    pub user_agent: String,
    /// Base URL that request paths are appended to, separated by `/`.
    pub endpoint: String,
    // NOTE(review): not read anywhere in this file; presumably a canonical form of `endpoint`
    // used by callers — confirm.
    pub canonical_endpoint: String,
    /// Default credentials; cloned into each request that has no on-behalf-of override.
    pub auth: Auth,
    // NOTE(review): not read in this file; presumably used for request tracing elsewhere.
    pub(crate) tracing: Arc<TracingComponent>,
}
65
66impl<C: Client> Management<C> {
67    pub fn new_request(
68        &self,
69        method: Method,
70        path: impl Into<String>,
71        content_type: impl Into<String>,
72        on_behalf_of: Option<OnBehalfOfInfo>,
73        headers: Option<HashMap<&str, &str>>,
74        body: Option<Bytes>,
75    ) -> Request {
76        let auth = if let Some(obo) = on_behalf_of {
77            Auth::OnBehalfOf(OnBehalfOfInfo {
78                username: obo.username,
79                password_or_domain: obo.password_or_domain,
80            })
81        } else {
82            self.auth.clone()
83        };
84
85        let mut req = Request::new(method, format!("{}/{}", self.endpoint, path.into()))
86            .auth(auth)
87            .user_agent(self.user_agent.clone())
88            .content_type(content_type.into())
89            .body(body);
90
91        if let Some(headers) = headers {
92            for (key, value) in headers.into_iter() {
93                req = req.add_header(key, value);
94            }
95        }
96
97        req
98    }
99
100    pub async fn execute(
101        &self,
102        method: Method,
103        path: impl Into<String>,
104        content_type: impl Into<String>,
105        on_behalf_of: Option<OnBehalfOfInfo>,
106        headers: Option<HashMap<&str, &str>>,
107        body: Option<Bytes>,
108    ) -> error::Result<Response> {
109        let req = self.new_request(method, path, content_type, on_behalf_of, headers, body);
110
111        self.http_client
112            .execute(req)
113            .await
114            .map_err(error::Error::from)
115    }
116
117    pub(crate) async fn decode_common_error(
118        method: Method,
119        path: String,
120        feature: impl Into<String>,
121        response: Response,
122    ) -> error::Error {
123        let status = response.status();
124        let url = response.url().to_string();
125        let body = match response.bytes().await {
126            Ok(b) => b,
127            Err(e) => {
128                return error::Error::new_message_error(format!(
129                    "could not parse response body: {e}"
130                ))
131            }
132        };
133
134        let body_str = match String::from_utf8(body.to_vec()) {
135            Ok(s) => s,
136            Err(e) => {
137                return error::Error::new_message_error(format!(
138                    "could not parse error response: {e}"
139                ))
140            }
141        };
142
143        let lower_body_str = body_str.to_lowercase();
144
145        let kind = if lower_body_str.contains("not found") && lower_body_str.contains("collection")
146        {
147            error::ServerErrorKind::CollectionNotFound
148        } else if lower_body_str.contains("not found") && lower_body_str.contains("scope") {
149            error::ServerErrorKind::ScopeNotFound
150        } else if lower_body_str.contains("not found") && lower_body_str.contains("bucket") {
151            error::ServerErrorKind::BucketNotFound
152        } else if (lower_body_str.contains("not found") && lower_body_str.contains("user"))
153            || lower_body_str.contains("unknown user")
154        {
155            error::ServerErrorKind::UserNotFound
156        } else if (lower_body_str.contains("not found") && lower_body_str.contains("group"))
157            || lower_body_str.contains("unknown group")
158        {
159            error::ServerErrorKind::GroupNotFound
160        } else if lower_body_str.contains("already exists") && lower_body_str.contains("collection")
161        {
162            error::ServerErrorKind::CollectionExists
163        } else if lower_body_str.contains("already exists") && lower_body_str.contains("scope") {
164            error::ServerErrorKind::ScopeExists
165        } else if lower_body_str.contains("already exists") && lower_body_str.contains("bucket") {
166            error::ServerErrorKind::BucketExists
167        } else if lower_body_str.contains("flush is disabled") {
168            error::ServerErrorKind::FlushDisabled
169        } else if lower_body_str.contains("requested resource not found")
170            || lower_body_str.contains("non existent bucket")
171        {
172            error::ServerErrorKind::BucketNotFound
173        } else if lower_body_str.contains("not yet complete, but will continue") {
174            error::ServerErrorKind::OperationDelayed
175        } else if status == 400 {
176            let s_err = Self::parse_for_invalid_arg(&lower_body_str);
177            if let Some(ia) = s_err {
178                let key = ia.0;
179                if FIELD_NAME_MAP.contains_key(&key) {
180                    error::ServerErrorKind::ServerInvalidArg {
181                        arg: key,
182                        reason: ia.1,
183                    }
184                } else {
185                    error::ServerErrorKind::Unknown
186                }
187            } else if lower_body_str.contains("not allowed on this type of bucket") {
188                error::ServerErrorKind::ServerInvalidArg {
189                    arg: "historyEnabled".to_string(),
190                    reason: body_str.to_string(),
191                }
192            } else if lower_body_str.contains("already loaded") {
193                error::ServerErrorKind::SampleAlreadyLoaded
194            } else if lower_body_str.contains("not a valid sample") {
195                error::ServerErrorKind::InvalidSampleBucket
196            } else {
197                error::ServerErrorKind::Unknown
198            }
199        } else if status == 404 {
200            error::ServerErrorKind::UnsupportedFeature {
201                feature: feature.into(),
202            }
203        } else if status == 401 {
204            error::ServerErrorKind::AccessDenied
205        } else {
206            error::ServerErrorKind::Unknown
207        };
208
209        error::ServerError::new(status, url, method, path, body_str, kind).into()
210    }
211
212    fn parse_for_invalid_arg(body: &str) -> Option<(String, String)> {
213        let inv_arg: ServerErrors = match serde_json::from_str(body) {
214            Ok(i) => i,
215            Err(_e) => {
216                return None;
217            }
218        };
219
220        if let Some((k, v)) = inv_arg.errors.into_iter().next() {
221            return Some((k, v));
222        }
223
224        None
225    }
226
227    pub async fn get_terse_cluster_config(
228        &self,
229        opts: &GetTerseClusterConfigOptions<'_>,
230    ) -> error::Result<TerseConfig> {
231        let method = Method::GET;
232        let path = "pools/default/nodeServices".to_string();
233
234        let resp = self
235            .execute(
236                method.clone(),
237                &path,
238                "",
239                opts.on_behalf_of_info.cloned(),
240                None,
241                None,
242            )
243            .await?;
244
245        if resp.status() != 200 {
246            return Err(
247                Self::decode_common_error(method, path, "get_terse_cluster_config", resp).await,
248            );
249        }
250
251        parse_response_json(resp).await
252    }
253
254    pub async fn get_full_cluster_config(
255        &self,
256        opts: &GetFullClusterConfigOptions<'_>,
257    ) -> error::Result<FullClusterConfig> {
258        let method = Method::GET;
259        let path = "pools/default".to_string();
260
261        let resp = self
262            .execute(
263                method.clone(),
264                &path,
265                "",
266                opts.on_behalf_of_info.cloned(),
267                None,
268                None,
269            )
270            .await?;
271
272        if resp.status() != 200 {
273            return Err(
274                Self::decode_common_error(method, path, "get_full_cluster_config", resp).await,
275            );
276        }
277
278        parse_response_json(resp).await
279    }
280
281    pub async fn get_terse_bucket_config(
282        &self,
283        opts: &GetTerseBucketConfigOptions<'_>,
284    ) -> error::Result<TerseConfig> {
285        let method = Method::GET;
286        let path = format!("pools/default/b/{}", opts.bucket_name);
287
288        let resp = self
289            .execute(
290                method.clone(),
291                &path,
292                "",
293                opts.on_behalf_of_info.cloned(),
294                None,
295                None,
296            )
297            .await?;
298
299        if resp.status() != 200 {
300            return Err(
301                Self::decode_common_error(method, path, "get_terse_bucket_config", resp).await,
302            );
303        }
304
305        parse_response_json(resp).await
306    }
307
308    pub async fn get_full_bucket_config(
309        &self,
310        opts: &GetFullBucketConfigOptions<'_>,
311    ) -> error::Result<FullBucketConfig> {
312        let method = Method::GET;
313        let path = format!("pools/default/buckets/{}", opts.bucket_name);
314
315        let resp = self
316            .execute(
317                method.clone(),
318                &path,
319                "",
320                opts.on_behalf_of_info.cloned(),
321                None,
322                None,
323            )
324            .await?;
325
326        if resp.status() != 200 {
327            return Err(
328                Self::decode_common_error(method, path, "get_full_bucket_config", resp).await,
329            );
330        }
331
332        parse_response_json(resp).await
333    }
334
335    pub async fn load_sample_bucket(
336        &self,
337        opts: &LoadSampleBucketOptions<'_>,
338    ) -> error::Result<()> {
339        let method = Method::POST;
340        let path = "sampleBuckets/install";
341        let body = Bytes::from(opts.bucket_name.to_string());
342
343        let resp = self
344            .execute(
345                method.clone(),
346                path,
347                "application/json",
348                opts.on_behalf_of_info.cloned(),
349                None,
350                Some(body),
351            )
352            .await?;
353
354        if resp.status() != 202 {
355            return Err(Self::decode_common_error(
356                method,
357                path.to_string(),
358                "load_sample_bucket",
359                resp,
360            )
361            .await);
362        }
363
364        Ok(())
365    }
366
367    pub async fn index_status(&self, opts: &IndexStatusOptions<'_>) -> error::Result<IndexStatus> {
368        let method = Method::GET;
369        let path = "indexStatus";
370
371        let resp = self
372            .execute(
373                method.clone(),
374                path,
375                "application/json",
376                opts.on_behalf_of_info.cloned(),
377                None,
378                None,
379            )
380            .await?;
381
382        if resp.status() != 200 {
383            return Err(
384                Self::decode_common_error(method, path.to_string(), "index_status", resp).await,
385            );
386        }
387
388        parse_response_json(resp).await
389    }
390
391    pub async fn get_auto_failover_settings(
392        &self,
393        opts: &GetAutoFailoverSettingsOptions<'_>,
394    ) -> error::Result<AutoFailoverSettings> {
395        let method = Method::GET;
396        let path = "settings/autoFailover";
397
398        let resp = self
399            .execute(
400                method.clone(),
401                path,
402                "",
403                opts.on_behalf_of_info.cloned(),
404                None,
405                None,
406            )
407            .await?;
408
409        if resp.status() != 200 {
410            return Err(Self::decode_common_error(
411                method,
412                path.to_string(),
413                "get_autofailover_settings",
414                resp,
415            )
416            .await);
417        }
418
419        parse_response_json(resp).await
420    }
421
422    pub async fn get_bucket_stats(
423        &self,
424        opts: &GetBucketStatsOptions<'_>,
425    ) -> error::Result<Box<RawValue>> {
426        let method = Method::GET;
427        let path = format!("pools/default/buckets/{}/stats", opts.bucket_name);
428
429        let resp = self
430            .execute(
431                method.clone(),
432                &path,
433                "",
434                opts.on_behalf_of_info.cloned(),
435                None,
436                None,
437            )
438            .await?;
439
440        if resp.status() != 200 {
441            return Err(Self::decode_common_error(method, path, "get_bucket_stats", resp).await);
442        }
443
444        parse_response_json(resp).await
445    }
446}
447
448pub(crate) async fn parse_response_json<T: DeserializeOwned>(resp: Response) -> error::Result<T> {
449    let body = resp
450        .bytes()
451        .await
452        .map_err(|e| error::Error::new_message_error(format!("could not read response: {e}")))?;
453
454    serde_json::from_slice(&body)
455        .map_err(|e| error::Error::new_message_error(format!("could not parse response: {e}")))
456}
457
/// Error body of the form `{"errors": {"<argument>": "<reason>", ...}}`.
///
/// Used by `parse_for_invalid_arg`. Bodies without an `errors` object of string values fail
/// to deserialize into this type.
#[derive(Deserialize)]
struct ServerErrors {
    errors: HashMap<String, String>,
}
462
463#[derive(Debug, Deserialize)]
464pub struct AutoFailoverSettings {
465    pub enabled: bool,
466    #[serde(deserialize_with = "deserialize_duration_secs")]
467    pub timeout: Duration,
468    pub count: usize,
469    #[serde(rename = "failoverOnDataDiskIssues")]
470    pub failover_on_data_disk_issues: FailoverOnDataDiskIssues,
471    #[serde(rename = "maxCount")]
472    pub max_count: usize,
473    pub can_abort_rebalance: bool,
474    #[serde(rename = "failoverPreserveDurabilityMajority")]
475    pub failover_preserve_durability_majority: Option<bool>,
476    #[serde(rename = "failoverOnDataDiskNonResponsiveness")]
477    pub failover_on_data_disk_non_responsiveness: Option<bool>,
478    #[serde(rename = "allowFailoverEphemeralNoReplicas")]
479    pub allow_failover_ephemeral_no_replicas: Option<bool>,
480}
481
482#[derive(Debug, Deserialize)]
483pub struct FailoverOnDataDiskIssues {
484    pub enabled: bool,
485    #[serde(rename = "timePeriod", deserialize_with = "deserialize_duration_secs")]
486    pub time_period: Duration,
487}
488
489#[derive(Debug, Deserialize)]
490pub struct FailoverOnDataDiskNonResponsiveness {
491    pub enabled: bool,
492    #[serde(rename = "timePeriod", deserialize_with = "deserialize_duration_secs")]
493    pub time_period: Duration,
494}
495
496fn deserialize_duration_secs<'de, D>(deserializer: D) -> Result<Duration, D::Error>
497where
498    D: Deserializer<'de>,
499{
500    let secs: u64 = Deserialize::deserialize(deserializer)?;
501    Ok(Duration::from_secs(secs))
502}