llmvm_core_lib/
service.rs

1use std::fs::create_dir;
2
3use async_stream::stream;
4use futures::StreamExt;
5use llmvm_protocol::{
6    error::ProtocolErrorType, service::BackendResponse, Core, GenerationRequest,
7    GenerationResponse, Message, NotificationStream, ProtocolError, ThreadInfo,
8};
9use llmvm_util::{get_file_path, DirType};
10use tracing::debug;
11
12use crate::{
13    error::CoreError,
14    threads::{
15        get_thread_infos, get_thread_messages, maybe_save_thread_messages_and_get_thread_id,
16    },
17    LLMVMCore, PROJECT_DIR_NAME,
18};
19
20#[llmvm_protocol::async_trait]
21impl Core for LLMVMCore {
22    async fn generate(
23        &self,
24        request: GenerationRequest,
25    ) -> std::result::Result<GenerationResponse, ProtocolError> {
26        async {
27            let (backend_request, model_description, thread_messages_to_save) =
28                self.prepare_for_generate(&request).await?;
29
30            let response = self
31                .send_generate_request(backend_request, &model_description)
32                .await?;
33
34            debug!("Response: {}", response.response);
35
36            let thread_id = maybe_save_thread_messages_and_get_thread_id(
37                &request,
38                response.response.clone(),
39                thread_messages_to_save,
40            )
41            .await?;
42
43            Ok(GenerationResponse {
44                response: response.response,
45                thread_id,
46            })
47        }
48        .await
49        .map_err(|e: CoreError| e.into())
50    }
51
52    async fn generate_stream(
53        &self,
54        request: GenerationRequest,
55    ) -> std::result::Result<NotificationStream<GenerationResponse>, ProtocolError> {
56        async {
57            let (backend_request, model_description, thread_messages_to_save) =
58                self.prepare_for_generate(&request).await?;
59
60            let mut stream = self
61                .send_generate_request_for_stream(backend_request, &model_description)
62                .await?;
63
64            Ok(stream! {
65                let mut full_response = String::new();
66                while let Some(result) = stream.next().await {
67                    match result {
68                        Ok(response) => match response {
69                            BackendResponse::GenerationStream(response) => {
70                                full_response.push_str(&response.response);
71                                yield Ok(GenerationResponse {
72                                    response: response.response,
73                                    thread_id: None
74                                });
75                            }
76                            _ => yield Err(CoreError::UnexpectedServiceResponse.into())
77                        },
78                        Err(e) => {
79                            yield Err(e);
80                        }
81                    }
82                }
83                if let Ok(thread_id) = maybe_save_thread_messages_and_get_thread_id(&request, full_response, thread_messages_to_save).await {
84                    yield Ok(GenerationResponse { response: String::new(), thread_id });
85                }
86            }.boxed())
87        }
88        .await
89        .map_err(|e: CoreError| e.into())
90    }
91
92    async fn get_last_thread_info(&self) -> std::result::Result<Option<ThreadInfo>, ProtocolError> {
93        async { Ok(get_thread_infos().await?.drain(0..1).next()) }
94            .await
95            .map_err(|e: CoreError| e.into())
96    }
97
98    async fn get_all_thread_infos(&self) -> std::result::Result<Vec<ThreadInfo>, ProtocolError> {
99        get_thread_infos().await.map_err(|e| e.into())
100    }
101
102    async fn get_thread_messages(
103        &self,
104        id: String,
105    ) -> std::result::Result<Vec<Message>, ProtocolError> {
106        get_thread_messages(&id).await.map_err(|e| e.into())
107    }
108
109    fn init_project(&self) -> std::result::Result<(), ProtocolError> {
110        create_dir(PROJECT_DIR_NAME).map_err(|error| ProtocolError {
111            error_type: ProtocolErrorType::Internal,
112            error: Box::new(error),
113        })?;
114        // Call the following util method for all dir types
115        // to trigger creation of project subdirectories
116        get_file_path(DirType::Prompts, "", true);
117        get_file_path(DirType::Presets, "", true);
118        get_file_path(DirType::Threads, "", true);
119        get_file_path(DirType::Logs, "", true);
120        get_file_path(DirType::Config, "", true);
121        get_file_path(DirType::Weights, "", true);
122        Ok(())
123    }
124}