dynamo-llm 0.2.1

Dynamo LLM Library
Documentation
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use super::*;

use anyhow::Result;
use nixl_sys::{MemoryRegion, NixlDescriptor, OptArgs, XferDescList, XferOp};
use std::ops::Range;

/// Write one block from `src` into `dst` using a NIXL `XferOp::Write` transfer.
///
/// When both the source and destination block data report `is_fully_contiguous()`,
/// the whole block is described by a single source/destination descriptor pair and
/// submitted as one transfer request. Otherwise the work is delegated to
/// [`write_layers_to`] over every layer of the block.
///
/// The call blocks: after posting the request it polls the agent for the transfer
/// status until the agent stops returning `true`.
///
/// # Arguments
///
/// * `src` - provider of the block data that is read.
/// * `dst` - provider of the block data that is written; its `worker_id` is passed
///   to the agent as the remote worker id of the request.
/// * `ctx` - transfer context from which the NIXL agent is obtained.
/// * `notify` - optional message; when present, notification is enabled and the
///   message bytes are attached to the arguments used to post the request.
///
/// # Errors
///
/// Returns an error if building a descriptor list, obtaining a block view, creating
/// or posting the transfer request, or querying the transfer status fails.
///
/// # Panics
///
/// Panics if `ctx` has no NIXL agent, or, on the non-contiguous path, if the source
/// and destination do not have the same number of layers.
pub fn write_block_to<'a, Source, Destination>(
    src: &'a Source,
    dst: &'a mut Destination,
    ctx: &TransferContext,
    notify: Option<String>,
) -> Result<()>
where
    Source: BlockDataProvider,
    Destination: BlockDataProviderMut,
{
    let src_data = src.block_data(private::PrivateToken);
    let dst_data = dst.block_data_mut(private::PrivateToken);

    if src_data.is_fully_contiguous() && dst_data.is_fully_contiguous() {
        let nixl_agent = ctx.nixl_agent().expect("NIXL agent not found");
        let remote_worker_id = dst_data.worker_id.to_string();

        let mut src_dl = XferDescList::new(src_data.storage_type().nixl_mem_type())?;
        let mut dst_dl = XferDescList::new(dst_data.storage_type().nixl_mem_type())?;

        // One descriptor per side covers the entire block.
        let src_desc = src_data.block_view()?.as_nixl_descriptor();
        let dst_desc = dst_data.block_view_mut()?.as_nixl_descriptor_mut();

        // SAFETY: NOTE(review): the address, size and device id are taken verbatim
        // from the descriptors produced by the block views above; presumably
        // `add_desc` requires them to describe memory that stays valid for the
        // duration of the transfer -- confirm against the nixl_sys contract.
        unsafe {
            src_dl.add_desc(
                src_desc.as_ptr() as usize,
                src_desc.size(),
                src_desc.device_id(),
            )?;

            dst_dl.add_desc(
                dst_desc.as_ptr() as usize,
                dst_desc.size(),
                dst_desc.device_id(),
            )?;
        }

        // The request itself is created without optional arguments; the
        // notification (if any) is supplied only when the request is posted.
        let xfer_req =
            nixl_agent.create_xfer_req(XferOp::Write, &src_dl, &dst_dl, &remote_worker_id, None)?;

        let mut xfer_args = OptArgs::new()?;

        if let Some(notify) = notify {
            xfer_args.set_has_notification(true)?;
            xfer_args.set_notification_message(notify.as_bytes())?;
        }

        let mut in_progress = nixl_agent.post_xfer_req(&xfer_req, Some(&xfer_args))?;

        // Poll until the agent no longer reports `true`. A failed status query is
        // returned to the caller instead of panicking inside the span.
        tracing::span!(tracing::Level::DEBUG, "Waiting for transfer to complete").in_scope(
            || -> Result<()> {
                while in_progress {
                    in_progress = nixl_agent.get_xfer_status(&xfer_req)?;
                }
                Ok(())
            },
        )?;
    } else {
        // Non-contiguous layout: transfer layer by layer. Both sides must agree on
        // the layer count for the per-layer pairing to make sense.
        assert_eq!(src_data.num_layers(), dst_data.num_layers());
        write_layers_to(0..src_data.num_layers(), src, dst, ctx, notify)?;
    }
    Ok(())
}

/// Write a range of layers from `src` into `dst` using a NIXL `XferOp::Write` transfer.
///
/// For every layer index in `layer_range`, one source descriptor and one
/// destination descriptor are appended to the respective descriptor lists; all
/// layers are then submitted together as a single transfer request.
///
/// The call blocks: after posting the request it polls the agent for the transfer
/// status until the agent stops returning `true`.
///
/// # Arguments
///
/// * `layer_range` - layer indices to transfer. Note that an empty range still
///   creates and posts a request with empty descriptor lists.
/// * `src` - provider of the block data that is read.
/// * `dst` - provider of the block data that is written; its `worker_id` is passed
///   to the agent as the remote worker id of the request.
/// * `ctx` - transfer context from which the NIXL agent is obtained.
/// * `notify` - optional message; when present, notification is enabled and the
///   message bytes are attached to the arguments used to create and post the
///   request.
///
/// # Errors
///
/// Returns an error if building a descriptor list, obtaining a layer view, creating
/// or posting the transfer request, or querying the transfer status fails.
///
/// # Panics
///
/// Panics if `ctx` has no NIXL agent. In builds with debug assertions enabled it
/// also panics if a source layer view and its destination layer view differ in size.
pub fn write_layers_to<'a, Source, Destination>(
    layer_range: Range<usize>,
    src: &'a Source,
    dst: &'a mut Destination,
    ctx: &TransferContext,
    notify: Option<String>,
) -> Result<()>
where
    Source: BlockDataProvider,
    Destination: BlockDataProviderMut,
{
    let src_data = src.block_data(private::PrivateToken);
    let dst_data = dst.block_data_mut(private::PrivateToken);

    let nixl_agent = ctx.nixl_agent().expect("NIXL agent not found");
    let remote_worker_id = dst_data.worker_id.to_string();

    let mut src_dl = XferDescList::new(src_data.storage_type().nixl_mem_type())?;
    let mut dst_dl = XferDescList::new(dst_data.storage_type().nixl_mem_type())?;

    for layer_idx in layer_range {
        let src_view = src_data.layer_view(layer_idx)?;
        let mut dst_view = dst_data.layer_view_mut(layer_idx)?;

        debug_assert_eq!(src_view.size(), dst_view.size());

        let src_desc = src_view.as_nixl_descriptor();
        let dst_desc = dst_view.as_nixl_descriptor_mut();

        // SAFETY: NOTE(review): the address, size and device id are taken verbatim
        // from the descriptors produced by the layer views above; presumably
        // `add_desc` requires them to describe memory that stays valid for the
        // duration of the transfer -- confirm against the nixl_sys contract.
        unsafe {
            src_dl.add_desc(
                src_desc.as_ptr() as usize,
                src_desc.size(),
                src_desc.device_id(),
            )?;

            dst_dl.add_desc(
                dst_desc.as_ptr() as usize,
                dst_desc.size(),
                dst_desc.device_id(),
            )?;
        }
    }

    let mut xfer_args = OptArgs::new()?;

    if let Some(notify) = notify {
        xfer_args.set_has_notification(true)?;
        xfer_args.set_notification_message(notify.as_bytes())?;
    }

    // The same optional arguments are handed to both request creation and posting.
    let xfer_req = nixl_agent.create_xfer_req(
        XferOp::Write,
        &src_dl,
        &dst_dl,
        &remote_worker_id,
        Some(&xfer_args),
    )?;

    let mut in_progress = nixl_agent.post_xfer_req(&xfer_req, Some(&xfer_args))?;

    // Poll until the agent no longer reports `true`. A failed status query is
    // returned to the caller instead of panicking inside the span.
    tracing::span!(tracing::Level::DEBUG, "Waiting for transfer to complete").in_scope(
        || -> Result<()> {
            while in_progress {
                in_progress = nixl_agent.get_xfer_status(&xfer_req)?;
            }
            Ok(())
        },
    )?;

    Ok(())
}