orign/
query.rs

1// src/query.rs
2use crate::entities::buffer;
3use crate::entities::feedback;
4use crate::entities::human;
5use crate::entities::llm;
6use crate::entities::model_deployment;
7use crate::entities::training_job;
8use sea_orm::*;
9use sea_orm::{ColumnTrait, DatabaseConnection, DbErr, EntityTrait, QueryFilter};
10
/// Namespace for the database lookup helpers implemented on it.
///
/// `Query` carries no state: every helper is an associated `async fn` that
/// receives the connection explicitly, e.g.
/// `Query::find_deployment_by_id(&db, "some-id").await`.
#[derive(Debug, Clone, Copy, Default)]
pub struct Query;
12
13impl Query {
14    pub async fn find_deployment_by_id(
15        db: &DatabaseConnection,
16        id: &str,
17    ) -> Result<Option<model_deployment::Model>, DbErr> {
18        model_deployment::Entity::find_by_id(id).one(db).await
19    }
20
21    pub async fn find_deployments_by_owner(
22        db: &DatabaseConnection,
23        owner_id: &str,
24    ) -> Result<Vec<model_deployment::Model>, DbErr> {
25        model_deployment::Entity::find()
26            .filter(model_deployment::Column::OwnerId.eq(owner_id))
27            .all(db)
28            .await
29    }
30
31    pub async fn find_deployments_by_framework(
32        db: &DatabaseConnection,
33        framework: &str,
34    ) -> Result<Vec<model_deployment::Model>, DbErr> {
35        model_deployment::Entity::find()
36            .filter(model_deployment::Column::Framework.eq(framework))
37            .all(db)
38            .await
39    }
40
41    pub async fn find_all_deployments(
42        db: &DatabaseConnection,
43    ) -> Result<Vec<model_deployment::Model>, DbErr> {
44        model_deployment::Entity::find().all(db).await
45    }
46
47    pub async fn find_deployments_by_kind(
48        db: &DatabaseConnection,
49        kind: &str,
50    ) -> Result<Vec<model_deployment::Model>, DbErr> {
51        model_deployment::Entity::find()
52            .filter(model_deployment::Column::Kind.eq(kind))
53            .all(db)
54            .await
55    }
56
57    // Helper function to find matching deployments based on model and provider criteria
58    pub async fn find_matching_deployments(
59        db: &DatabaseConnection,
60        model: Option<&str>,
61        framework: Option<&str>,
62        kind: &str, // kind is now required
63    ) -> Result<Vec<model_deployment::Model>, DbErr> {
64        // Start with kind filter as it's always required
65        let mut query =
66            model_deployment::Entity::find().filter(model_deployment::Column::Kind.eq(kind));
67
68        // Add provider filter if specified
69        if let Some(framework) = framework {
70            query = query.filter(model_deployment::Column::Framework.eq(framework));
71        }
72
73        // Execute the query
74        let deployments = query.all(db).await?;
75
76        // If model is specified, filter deployments by checking model field and params
77        if let Some(model_name) = model {
78            Ok(deployments
79                .into_iter()
80                .filter(|deployment| {
81                    // Check the model field first
82                    if deployment.model == model_name {
83                        return true;
84                    }
85
86                    // Then check the params if they exist
87                    if let Some(params) = &deployment.params {
88                        if let Some(param_model) = params.get("model") {
89                            if let Some(model_str) = param_model.as_str() {
90                                return model_str == model_name;
91                            }
92                        }
93                    }
94                    false
95                })
96                .collect())
97        } else {
98            Ok(deployments)
99        }
100    }
101
102    pub async fn find_training_job_by_id(
103        db: &DatabaseConnection,
104        id: &str,
105    ) -> Result<Option<training_job::Model>, DbErr> {
106        training_job::Entity::find_by_id(id).one(db).await
107    }
108
109    pub async fn find_training_jobs_by_owners(
110        db: &DatabaseConnection,
111        owner_ids: &[&str],
112    ) -> Result<Vec<training_job::Model>, DbErr> {
113        training_job::Entity::find()
114            .filter(training_job::Column::OwnerId.is_in(owner_ids.iter().copied()))
115            .all(db)
116            .await
117    }
118
119    pub async fn find_training_jobs_by_queue(
120        db: &DatabaseConnection,
121        queue_name: &str,
122    ) -> Result<Vec<training_job::Model>, DbErr> {
123        training_job::Entity::find()
124            .filter(training_job::Column::Queue.eq(queue_name))
125            .all(db)
126            .await
127    }
128
129    pub async fn find_training_jobs_by_owner(
130        db: &DatabaseConnection,
131        owner_id: &str,
132    ) -> Result<Vec<training_job::Model>, DbErr> {
133        training_job::Entity::find()
134            .filter(training_job::Column::OwnerId.eq(owner_id))
135            .all(db)
136            .await
137    }
138
139    pub async fn find_training_job_by_id_and_owners(
140        db: &DatabaseConnection,
141        id: &str,
142        owner_ids: &[&str],
143    ) -> Result<training_job::Model, DbErr> {
144        training_job::Entity::find()
145            .filter(
146                training_job::Column::Id
147                    .eq(id.to_owned())
148                    .and(training_job::Column::OwnerId.is_in(owner_ids.iter().copied())),
149            )
150            .one(db)
151            .await?
152            .ok_or_else(|| DbErr::RecordNotFound(format!("No training job found with id {}", id)))
153    }
154
155    pub async fn find_training_job_by_id_and_owner(
156        db: &DatabaseConnection,
157        id: &str,
158        owner_id: &str,
159    ) -> Result<training_job::Model, DbErr> {
160        training_job::Entity::find()
161            .filter(
162                training_job::Column::Id
163                    .eq(id.to_owned())
164                    .and(training_job::Column::OwnerId.eq(owner_id.to_owned())),
165            )
166            .one(db)
167            .await?
168            .ok_or_else(|| DbErr::RecordNotFound(format!("No training job found with id {}", id)))
169    }
170
171    pub async fn find_training_job_by_resource_name(
172        db: &DatabaseConnection,
173        resource_name: &str,
174    ) -> Result<Option<training_job::Model>, DbErr> {
175        training_job::Entity::find()
176            .filter(training_job::Column::ResourceName.eq(resource_name))
177            .one(db)
178            .await
179    }
180
181    pub async fn find_buffer_by_id(
182        db: &DatabaseConnection,
183        id: &str,
184    ) -> Result<Option<buffer::Model>, DbErr> {
185        buffer::Entity::find_by_id(id).one(db).await
186    }
187
188    pub async fn find_buffer_by_name_and_owners(
189        db: &DatabaseConnection,
190        name: &str,
191        namespace: &str,
192        owner_ids: &[&str],
193    ) -> Result<Option<buffer::Model>, DbErr> {
194        buffer::Entity::find()
195            .filter(buffer::Column::Name.eq(name))
196            .filter(buffer::Column::Namespace.eq(namespace))
197            .filter(buffer::Column::OwnerId.is_in(owner_ids.iter().copied()))
198            .one(db)
199            .await
200    }
201
202    pub async fn find_buffer_by_name(
203        db: &DatabaseConnection,
204        namespace: &str,
205        name: &str,
206    ) -> Result<Option<buffer::Model>, DbErr> {
207        buffer::Entity::find()
208            .filter(buffer::Column::Name.eq(name))
209            .filter(buffer::Column::Namespace.eq(namespace))
210            .one(db)
211            .await
212    }
213
214    pub async fn find_buffers_by_owner(
215        db: &DatabaseConnection,
216        owner_id: &str,
217    ) -> Result<Vec<buffer::Model>, DbErr> {
218        buffer::Entity::find()
219            .filter(buffer::Column::OwnerId.eq(owner_id))
220            .all(db)
221            .await
222    }
223
224    pub async fn find_buffers_by_owners(
225        db: &DatabaseConnection,
226        owner_ids: &[&str],
227    ) -> Result<Vec<buffer::Model>, DbErr> {
228        buffer::Entity::find()
229            .filter(buffer::Column::OwnerId.is_in(owner_ids.iter().cloned()))
230            .all(db)
231            .await
232    }
233
234    pub async fn find_buffer_by_id_and_owners(
235        db: &DatabaseConnection,
236        id: &str,
237        owner_ids: &[&str],
238    ) -> Result<Option<buffer::Model>, DbErr> {
239        buffer::Entity::find()
240            .filter(buffer::Column::Id.eq(id))
241            .filter(buffer::Column::OwnerId.is_in(owner_ids.iter().copied()))
242            .one(db)
243            .await
244    }
245
246    pub async fn find_llms_by_owners(
247        db: &DatabaseConnection,
248        owner_ids: &[&str],
249    ) -> Result<Vec<llm::Model>, DbErr> {
250        llm::Entity::find()
251            .filter(llm::Column::OwnerId.is_in(owner_ids.iter().copied()))
252            .all(db)
253            .await
254    }
255
256    // Returns one LLM record that matches name + namespace (or `None` if not found).
257    pub async fn find_llm_by_name_and_namespace(
258        db: &DatabaseConnection,
259        name: &str,
260        namespace: &str,
261    ) -> Result<Option<llm::Model>, DbErr> {
262        llm::Entity::find()
263            .filter(llm::Column::Name.eq(name))
264            .filter(llm::Column::Namespace.eq(namespace))
265            .one(db)
266            .await
267    }
268
269    pub async fn find_llm_by_name_and_namespace_and_owners(
270        db: &DatabaseConnection,
271        name: &str,
272        namespace: &str,
273        owner_ids: &[&str],
274    ) -> Result<Option<llm::Model>, DbErr> {
275        llm::Entity::find()
276            .filter(llm::Column::Name.eq(name))
277            .filter(llm::Column::Namespace.eq(namespace))
278            .filter(llm::Column::OwnerId.is_in(owner_ids.iter().copied()))
279            .one(db)
280            .await
281    }
282
283    pub async fn find_humans_by_owners(
284        db: &DatabaseConnection,
285        owner_ids: &[&str],
286    ) -> Result<Vec<human::Model>, DbErr> {
287        human::Entity::find()
288            .filter(human::Column::OwnerId.is_in(owner_ids.iter().copied()))
289            .all(db)
290            .await
291    }
292
293    pub async fn find_human_by_name_and_namespace_and_owners(
294        db: &DatabaseConnection,
295        name: &str,
296        namespace: &str,
297        owner_ids: &[&str],
298    ) -> Result<Option<human::Model>, DbErr> {
299        human::Entity::find()
300            .filter(human::Column::Name.eq(name))
301            .filter(human::Column::Namespace.eq(namespace))
302            .filter(human::Column::OwnerId.is_in(owner_ids.iter().copied()))
303            .one(db)
304            .await
305    }
306    pub async fn find_feedback_by_id_and_human_id(
307        db: &DatabaseConnection,
308        feedback_id: &str,
309        human_id: &str,
310    ) -> Result<Option<feedback::Model>, DbErr> {
311        feedback::Entity::find_by_id(feedback_id)
312            .filter(feedback::Column::HumanId.eq(human_id))
313            .one(db)
314            .await
315    }
316}