milvus/
client.rs

// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
16
17use crate::collection::Collection;
18use crate::config::RPC_TIMEOUT;
19use crate::error::{Error, Result};
20use crate::options::CreateCollectionOptions;
21pub use crate::proto::common::ConsistencyLevel;
22use crate::proto::common::MsgType;
23use crate::proto::milvus::milvus_service_client::MilvusServiceClient;
24use crate::proto::milvus::{
25    CreateCollectionRequest, DescribeCollectionRequest, DropCollectionRequest, FlushRequest,
26    HasCollectionRequest, ShowCollectionsRequest,
27};
28use crate::schema::CollectionSchema;
29use crate::utils::{new_msg, status_to_result};
30use base64::Engine;
31use base64::engine::general_purpose;
32use prost::bytes::BytesMut;
33use prost::Message;
34use tonic::{Request};
35use tonic::service::Interceptor;
36use std::collections::HashMap;
37use std::convert::TryInto;
38use std::time::Duration;
39use tonic::codegen::{StdError, InterceptedService};
40use tonic::transport::Channel;
41
/// gRPC interceptor that stamps an `authorization` metadata entry onto
/// outgoing requests.
#[derive(Clone)]
pub struct AuthInterceptor {
    /// Value sent as the `authorization` header; `None` means requests are
    /// forwarded without any authorization metadata.
    token: Option<String>,
}
46
47impl Interceptor for AuthInterceptor {
48    fn call(&mut self, mut req: Request<()>) -> std::result::Result<tonic::Request<()>, tonic::Status> {
49        if let Some(ref token) = self.token {
50            let header_value = format!("{}", token);
51            req.metadata_mut()
52                .insert("authorization", header_value.parse().unwrap());
53        }
54
55        Ok(req)
56    }
57}
58
59#[derive(Clone)]
60pub struct ClientBuilder<D> {
61    dst: D,
62    username: Option<String>,
63    password: Option<String>,
64}
65
66impl<D> ClientBuilder<D>
67where
68    D: TryInto<tonic::transport::Endpoint> + Clone,
69    D::Error: Into<StdError>,
70    D::Error: std::fmt::Debug,
71{
72    pub fn new(dst: D) -> Self {
73        Self {
74            dst,
75            username: None,
76            password: None,
77        }
78    }
79
80    pub fn username(mut self, username: &str) -> Self {
81        self.username = Some(username.to_owned());
82        self
83    }
84
85    pub fn password(mut self, password: &str) -> Self {
86        self.password = Some(password.to_owned());
87        self
88    }
89
90    pub async fn build(self) -> Result<Client> {
91        Client::with_timeout(self.dst, RPC_TIMEOUT, self.username, self.password).await
92    }
93}
94
95#[derive(Clone)]
96pub struct Client {
97    client: MilvusServiceClient<InterceptedService<Channel, AuthInterceptor>>,
98}
99
100impl Client {
101    pub async fn new<D>(dst: D) -> Result<Self>
102    where
103        D: TryInto<tonic::transport::Endpoint>,
104        D::Error: Into<StdError>,
105        D::Error: std::fmt::Debug,
106    {
107        Self::with_timeout(dst, RPC_TIMEOUT, None, None).await
108    }
109
110    pub async fn with_timeout<D>(
111        dst: D,
112        timeout: Duration,
113        username: Option<String>,
114        password: Option<String>,
115    ) -> Result<Self>
116    where
117        D: TryInto<tonic::transport::Endpoint>,
118        D::Error: Into<StdError>,
119        D::Error: std::fmt::Debug,
120    {
121        let mut dst: tonic::transport::Endpoint = dst.try_into().map_err(|err| {
122            Error::InvalidParameter("url".to_owned(), format!("to parse {:?}", err))
123        })?;
124
125        dst = dst.timeout(timeout);
126
127        let token = match (username, password) {
128            (Some(username), Some(password)) => {
129                let auth_token = format!("{}:{}", username, password);
130                let auth_token = general_purpose::STANDARD.encode(auth_token);
131                Some(auth_token)
132            }
133            _ => None,
134        };
135
136        let auth_interceptor = AuthInterceptor { token };
137
138        let conn = tonic::transport::Endpoint::new(dst)?.connect().await?;
139
140        let client = MilvusServiceClient::with_interceptor(conn, auth_interceptor);
141
142        Ok(Self { client })
143    }
144
145    pub async fn create_collection(
146        &self,
147        schema: CollectionSchema,
148        options: Option<CreateCollectionOptions>,
149    ) -> Result<Collection> {
150        let options = options.unwrap_or_default();
151        let schema: crate::proto::schema::CollectionSchema = schema.into();
152        let mut buf = BytesMut::new();
153
154        schema.encode(&mut buf)?;
155
156        let status = self
157            .client
158            .clone()
159            .create_collection(CreateCollectionRequest {
160                base: Some(new_msg(MsgType::CreateCollection)),
161                collection_name: schema.name.to_string(),
162                schema: buf.to_vec(),
163                shards_num: options.shard_num,
164                consistency_level: options.consistency_level as i32,
165                ..Default::default()
166            })
167            .await?
168            .into_inner();
169
170        status_to_result(&Some(status))?;
171
172        Ok(self.get_collection(&schema.name).await?)
173    }
174
175    pub async fn get_collection(&self, collection_name: &str) -> Result<Collection> {
176        let resp = self
177            .client
178            .clone()
179            .describe_collection(DescribeCollectionRequest {
180                base: Some(new_msg(MsgType::DescribeCollection)),
181                db_name: "".to_owned(),
182                collection_name: collection_name.to_owned(),
183                collection_id: 0,
184                time_stamp: 0,
185            })
186            .await?
187            .into_inner();
188
189        status_to_result(&resp.status)?;
190
191        Ok(Collection::new(self.client.clone(), resp))
192    }
193
194    pub async fn has_collection<S>(&self, name: S) -> Result<bool>
195    where
196        S: Into<String>,
197    {
198        let name = name.into();
199        let res = self
200            .client
201            .clone()
202            .has_collection(HasCollectionRequest {
203                base: Some(new_msg(MsgType::HasCollection)),
204                db_name: "".to_string(),
205                collection_name: name.clone(),
206                time_stamp: 0,
207            })
208            .await?
209            .into_inner();
210
211        status_to_result(&res.status)?;
212
213        Ok(res.value)
214    }
215
216    pub async fn drop_collection<S>(&self, name: S) -> Result<()>
217    where
218        S: Into<String>,
219    {
220        status_to_result(&Some(
221            self.client
222                .clone()
223                .drop_collection(DropCollectionRequest {
224                    base: Some(new_msg(MsgType::DropCollection)),
225                    collection_name: name.into(),
226                    ..Default::default()
227                })
228                .await?
229                .into_inner(),
230        ))
231    }
232
233    pub async fn list_collections(&self) -> Result<Vec<String>> {
234        let response = self
235            .client
236            .clone()
237            .show_collections(ShowCollectionsRequest {
238                base: Some(new_msg(MsgType::ShowCollections)),
239                ..Default::default()
240            })
241            .await?
242            .into_inner();
243
244        status_to_result(&response.status)?;
245        Ok(response.collection_names)
246    }
247
248    pub async fn flush_collections<C>(&self, collections: C) -> Result<HashMap<String, Vec<i64>>>
249    where
250        C: IntoIterator,
251        C::Item: ToString,
252    {
253        let res = self
254            .client
255            .clone()
256            .flush(FlushRequest {
257                base: Some(new_msg(MsgType::Flush)),
258                db_name: "".to_string(),
259                collection_names: collections.into_iter().map(|x| x.to_string()).collect(),
260            })
261            .await?
262            .into_inner();
263
264        status_to_result(&res.status)?;
265
266        Ok(res
267            .coll_seg_i_ds
268            .into_iter()
269            .map(|(k, v)| (k, v.data))
270            .collect())
271    }
272}